Regenerate (branch of stop) (#2157)

* add regenerate

* functional once again post rebase but quite ugly

* validated + cleaner UI

* more robust implementation for first messages

* squash

* remove parameter

* proper margin

* clarify for future programmers

* remove some logs

* self nit pick - smoother ux

* more self-nits

* stroke line cap

* rebase
This commit is contained in:
pablodanswer
2024-08-22 12:06:44 -07:00
committed by GitHub
parent 9d5db05e4b
commit 197b62aed1
15 changed files with 605 additions and 61 deletions

View File

@@ -0,0 +1,28 @@
"""Added alternate model to chat message
Revision ID: ee3f4b47fad5
Revises: 2d2304e27d8c
Create Date: 2024-08-12 00:11:50.915845
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ee3f4b47fad5"
down_revision = "2d2304e27d8c"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column("overridden_model", sa.String(length=255), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_message", "overridden_model")

View File

@@ -36,6 +36,8 @@ def create_chat_chain(
chat_session_id: int,
db_session: Session,
prefetch_tool_calls: bool = True,
# Optional id at which we finish processing
stop_at_message_id: int | None = None,
) -> tuple[ChatMessage, list[ChatMessage]]:
"""Build the linear chain of messages without including the root message"""
mainline_messages: list[ChatMessage] = []
@@ -61,7 +63,12 @@ def create_chat_chain(
current_message: ChatMessage | None = root_message
while current_message is not None:
child_msg = current_message.latest_child_message
if not child_msg:
# Break if at the end of the chain
# or have reached the `final_id` of the submitted message
if not child_msg or (
stop_at_message_id and current_message.id == stop_at_message_id
):
break
current_message = id_to_msg.get(child_msg)

View File

@@ -351,7 +351,15 @@ def stream_chat_message_objects(
parent_message = root_message
user_message = None
if not use_existing_user_message:
if new_msg_req.regenerate:
final_msg, history_msgs = create_chat_chain(
stop_at_message_id=parent_id,
chat_session_id=chat_session_id,
db_session=db_session,
)
elif not use_existing_user_message:
# Create new message at the right place in the tree and update the parent's child pointer
# Don't commit yet until we verify the chat message chain
user_message = create_new_chat_message(
@@ -470,12 +478,18 @@ def stream_chat_message_objects(
user_message_id=user_message.id if user_message else None,
reserved_assistant_message_id=reserved_message_id,
)
overridden_model = (
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
)
# Cannot determine these without the LLM step or breaking out early
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=final_msg,
prompt_id=prompt_id,
overridden_model=overridden_model,
# message=,
# rephrased_query=,
# token_count=,

View File

@@ -443,6 +443,7 @@ def create_new_chat_message(
tool_calls: list[ToolCall] | None = None,
commit: bool = True,
reserved_message_id: int | None = None,
overridden_model: str | None = None,
) -> ChatMessage:
if reserved_message_id is not None:
# Edit existing message
@@ -462,6 +463,7 @@ def create_new_chat_message(
existing_message.tool_calls = tool_calls if tool_calls else []
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
new_chat_message = existing_message
else:
@@ -480,6 +482,7 @@ def create_new_chat_message(
tool_calls=tool_calls if tool_calls else [],
error=error,
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
)
db_session.add(new_chat_message)
@@ -719,6 +722,7 @@ def translate_db_message_to_chat_message_detail(
for tool_call in chat_message.tool_calls
],
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)
return chat_msg_detail

View File

@@ -897,6 +897,7 @@ class ChatMessage(Base):
Integer, ForeignKey("persona.id"), nullable=True
)
overridden_model: Mapped[str | None] = mapped_column(String, nullable=True)
parent_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
latest_child_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
message: Mapped[str] = mapped_column(Text)

View File

@@ -339,11 +339,9 @@ def get_llm_max_tokens(
return GEN_AI_MAX_TOKENS
try:
model_obj = (
model_map.get(f"{model_provider}/{model_name}")
or model_map.get(model_name)
or model_map[model_name.split("/")[1]]
)
model_obj = model_map.get(f"{model_provider}/{model_name}")
if not model_obj:
model_obj = model_map[model_name]
if "max_input_tokens" in model_obj:
return model_obj["max_input_tokens"]

View File

@@ -98,7 +98,11 @@ class CreateChatMessageRequest(ChunkContext):
# will disable Query Rewording if specified
query_override: str | None = None
# enables additional handling to ensure that we regenerate with a given user message ID
regenerate: bool | None = None
# allows the caller to override the Persona / Prompt
# these do not persist in the chat thread details
llm_override: LLMOverride | None = None
prompt_override: PromptOverride | None = None
@@ -179,6 +183,7 @@ class ChatMessageDetail(BaseModel):
message_type: MessageType
time_sent: datetime
alternate_assistant_id: str | None
overridden_model: str | None
# Dict mapping citation number to db_doc_id
chat_session_id: int | None = None
citations: dict[int, int] | None

View File

@@ -390,10 +390,12 @@ export function ChatPage({
const [message, setMessage] = useState(
searchParams.get(SEARCH_PARAM_NAMES.USER_MESSAGE) || ""
);
const [completeMessageDetail, setCompleteMessageDetail] = useState<{
sessionId: number | null;
messageMap: Map<number, Message>;
}>({ sessionId: null, messageMap: new Map() });
const upsertToCompleteMessageMap = ({
messages,
completeMessageMapOverride,
@@ -413,6 +415,7 @@ export function ChatPage({
const frozenCompleteMessageMap =
completeMessageMapOverride || completeMessageDetail.messageMap;
const newCompleteMessageMap = structuredClone(frozenCompleteMessageMap);
if (newCompleteMessageMap.size === 0) {
const systemMessageId = messages[0].parentMessageId || SYSTEM_MESSAGE_ID;
const firstMessageId = messages[0].messageId;
@@ -471,8 +474,17 @@ export function ChatPage({
const messageHistory = buildLatestMessageChain(
completeMessageDetail.messageMap
);
const [submittedMessage, setSubmittedMessage] = useState("");
const [chatState, setChatState] = useState<ChatState>("input");
interface RegenerationState {
regenerating: boolean;
finalMessageIndex: number;
}
const [regenerationState, setRegenerationState] =
useState<RegenerationState | null>(null);
const [abortController, setAbortController] =
useState<AbortController | null>(null);
@@ -719,6 +731,8 @@ export function ChatPage({
forceSearch,
isSeededChat,
alternativeAssistantOverride = null,
modelOverRide,
regenerationRequest,
}: {
messageIdToResend?: number;
messageOverride?: string;
@@ -726,6 +740,8 @@ export function ChatPage({
forceSearch?: boolean;
isSeededChat?: boolean;
alternativeAssistantOverride?: Persona | null;
modelOverRide?: LlmOverride;
regenerationRequest?: RegenerationRequest | null;
} = {}) => {
if (chatState != "input") {
setPopup({
@@ -735,8 +751,14 @@ export function ChatPage({
return;
}
setRegenerationState(
regenerationRequest
? { regenerating: true, finalMessageIndex: messageIdToResend || 0 }
: null
);
setChatState("loading");
const controller = new AbortController();
setAbortController(controller);
@@ -770,12 +792,14 @@ export function ChatPage({
const messageToResendIndex = messageToResend
? messageHistory.indexOf(messageToResend)
: null;
if (!messageToResend && messageIdToResend !== undefined) {
setPopup({
message:
"Failed to re-send message - please refresh the page and try again.",
type: "error",
});
setRegenerationState(null);
setChatState("input");
return;
}
@@ -789,6 +813,7 @@ export function ChatPage({
messageToResendIndex !== null
? messageHistory.slice(0, messageToResendIndex)
: messageHistory;
let parentMessage =
messageToResendParent ||
(currMessageHistory.length > 0
@@ -827,8 +852,11 @@ export function ChatPage({
} = null;
try {
const mapKeys = Array.from(completeMessageDetail.messageMap.keys());
const systemMessage = Math.min(...mapKeys);
const lastSuccessfulMessageId =
getLastSuccessfulMessageId(currMessageHistory);
getLastSuccessfulMessageId(currMessageHistory) || systemMessage;
const stack = new CurrentMessageFIFO();
updateCurrentMessageFIFO(stack, {
@@ -836,7 +864,9 @@ export function ChatPage({
message: currMessage,
alternateAssistantId: currentAssistantId,
fileDescriptors: currentMessageFiles,
parentMessageId: lastSuccessfulMessageId,
parentMessageId:
regenerationRequest?.parentMessage.messageId ||
lastSuccessfulMessageId,
chatSessionId: currChatSessionId,
promptId: liveAssistant?.prompts[0]?.id || 0,
filters: buildFilters(
@@ -853,12 +883,14 @@ export function ChatPage({
.map((document) => document.db_doc_id as number),
queryOverride,
forceSearch,
regenerate: regenerationRequest !== undefined,
modelProvider:
modelOverRide?.name ||
llmOverrideManager.llmOverride.name ||
llmOverrideManager.globalDefault.name ||
undefined,
modelVersion:
modelOverRide?.modelName ||
llmOverrideManager.llmOverride.modelName ||
searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
llmOverrideManager.globalDefault.modelName ||
@@ -900,15 +932,18 @@ export function ChatPage({
// we will use tempMessages until the regenerated message is complete
messageUpdates = [
{
messageId: user_message_id,
messageId: regenerationRequest
? regenerationRequest?.parentMessage?.messageId!
: user_message_id,
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
parentMessageId: parentMessage?.messageId || null,
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
},
];
if (parentMessage) {
if (parentMessage && !regenerationRequest) {
messageUpdates.push({
...parentMessage,
childrenMessageIds: (
@@ -934,6 +969,8 @@ export function ChatPage({
assistant_message_id,
user_message_id,
};
setRegenerationState(null);
} else {
const { user_message_id, frozenMessageMap, frozenSessionId } =
initialFetchDetails;
@@ -993,8 +1030,20 @@ export function ChatPage({
parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!;
const updateFn = (messages: Message[]) => {
const replacementsMap = null;
upsertToCompleteMessageMap({
const replacementsMap = regenerationRequest
? new Map([
[
regenerationRequest?.parentMessage?.messageId,
regenerationRequest?.parentMessage?.messageId,
],
[
regenerationRequest?.messageId,
initialFetchDetails?.assistant_message_id,
],
] as [number, number][])
: null;
return upsertToCompleteMessageMap({
messages: messages,
replacementsMap: replacementsMap,
completeMessageMapOverride: frozenMessageMap,
@@ -1004,13 +1053,19 @@ export function ChatPage({
updateFn([
{
messageId: initialFetchDetails.user_message_id!,
messageId: regenerationRequest
? regenerationRequest?.parentMessage?.messageId!
: initialFetchDetails.user_message_id!,
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
parentMessageId: error ? null : lastSuccessfulMessageId,
childrenMessageIds: [initialFetchDetails.assistant_message_id!],
childrenMessageIds: [
...(regenerationRequest?.parentMessage?.childrenMessageIds ||
[]),
initialFetchDetails.assistant_message_id!,
],
latestChildMessageId: initialFetchDetails.assistant_message_id,
},
{
@@ -1024,9 +1079,12 @@ export function ChatPage({
citations: finalMessage?.citations || {},
files: finalMessage?.files || aiMessageImages || [],
toolCalls: finalMessage?.tool_calls || toolCalls,
parentMessageId: initialFetchDetails.user_message_id,
parentMessageId: regenerationRequest
? regenerationRequest?.parentMessage?.messageId!
: initialFetchDetails.user_message_id,
alternateAssistantID: alternativeAssistant?.id,
stackTrace: stackTrace,
overridden_model: finalMessage?.overridden_model,
},
]);
}
@@ -1060,6 +1118,7 @@ export function ChatPage({
completeMessageMapOverride: completeMessageDetail.messageMap,
});
}
setRegenerationState(null);
setChatState("input");
if (isNewSession) {
if (finalMessage) {
@@ -1309,6 +1368,22 @@ export function ChatPage({
};
const secondsUntilExpiration = getSecondsUntilExpiration(user);
interface RegenerationRequest {
messageId: number;
parentMessage: Message;
}
function createRegenerator(regenerationRequest: RegenerationRequest) {
// Returns new function that only needs `modelOverRide` to be specified when called
return async function (modelOverRide: LlmOverride) {
return await onSubmit({
modelOverRide,
messageIdToResend: regenerationRequest.parentMessage.messageId,
regenerationRequest,
});
};
}
return (
<>
<HealthCheckBanner secondsUntilExpiration={secondsUntilExpiration} />
@@ -1494,7 +1569,7 @@ export function ChatPage({
)}
<div
className={
"mt-4 -ml-4 w-full mx-auto " +
"mt-4 -ml-4 w-full mx-auto " +
"absolute mobile:top-0 desktop:top-12 left-0" +
(hasPerformedInitialScroll ? "" : "invisible")
}
@@ -1503,10 +1578,19 @@ export function ChatPage({
const messageMap =
completeMessageDetail.messageMap;
const messageReactComponentKey = `${i}-${completeMessageDetail.sessionId}`;
const parentMessage = message.parentMessageId
? messageMap.get(message.parentMessageId)
: null;
if (
regenerationState &&
regenerationState.regenerating &&
message.messageId >
regenerationState.finalMessageIndex
) {
return <></>;
}
if (message.type === "user") {
const parentMessage = message.parentMessageId
? messageMap.get(message.parentMessageId)
: null;
return (
<div key={messageReactComponentKey}>
<HumanMessage
@@ -1514,9 +1598,6 @@ export function ChatPage({
content={message.message}
files={message.files}
messageId={message.messageId}
otherMessagesCanSwitchTo={
parentMessage?.childrenMessageIds || []
}
onEdit={(editedContent) => {
const parentMessageId =
message.parentMessageId!;
@@ -1536,6 +1617,9 @@ export function ChatPage({
messageOverride: editedContent,
});
}}
otherMessagesCanSwitchTo={
parentMessage?.childrenMessageIds || []
}
onMessageSelection={(messageId) => {
const newCompleteMessageMap = new Map(
messageMap
@@ -1576,6 +1660,15 @@ export function ChatPage({
)
: null;
if (
regenerationState &&
regenerationState.regenerating &&
// chatState == "loading" &&
message.messageId >
regenerationState.finalMessageIndex - 1
) {
return <></>;
}
return (
<div
key={messageReactComponentKey}
@@ -1586,6 +1679,33 @@ export function ChatPage({
}
>
<AIMessage
overriddenModel={message.overridden_model}
regenerate={createRegenerator({
messageId: message.messageId,
parentMessage: parentMessage!,
})}
otherMessagesCanSwitchTo={
parentMessage?.childrenMessageIds || []
}
onMessageSelection={(messageId) => {
const newCompleteMessageMap = new Map(
messageMap
);
newCompleteMessageMap.get(
message.parentMessageId!
)!.latestChildMessageId = messageId;
setCompleteMessageDetail({
sessionId:
completeMessageDetail.sessionId,
messageMap: newCompleteMessageMap,
});
setSelectedMessageForDocDisplay(
messageId
);
// set message as latest so we can edit this message
// and so it sticks around on page reload
setMessageAsLatest(messageId);
}}
isActive={messageHistory.length - 1 == i}
selectedDocuments={selectedDocuments}
toggleDocumentSelection={
@@ -1598,6 +1718,7 @@ export function ChatPage({
}
messageId={message.messageId}
content={message.message}
// content={message.message}
files={message.files}
query={
messageHistory[i]?.query || undefined
@@ -1739,6 +1860,7 @@ export function ChatPage({
}
})}
{chatState == "loading" &&
!regenerationState?.regenerating &&
messageHistory[messageHistory.length - 1]?.type !=
"user" && (
<HumanMessage
@@ -1746,6 +1868,7 @@ export function ChatPage({
content={submittedMessage}
/>
)}
{chatState == "loading" && (
<div
key={`${messageHistory.length}-${chatSessionIdRef.current}`}

View File

@@ -0,0 +1,184 @@
import { useChatContext } from "@/components/context/ChatContext";
import {
getDisplayNameForModel,
LlmOverride,
useLlmOverride,
} from "@/lib/hooks";
import {
DefaultDropdownElement,
StringOrNumberOption,
} from "@/components/Dropdown";
import { Persona } from "@/app/admin/assistants/interfaces";
import { destructureValue, getFinalLLM, structureValue } from "@/lib/llm/utils";
import { useState } from "react";
import { Hoverable } from "@/components/Hoverable";
import { Popover } from "@/components/popover/Popover";
import { FiStar } from "react-icons/fi";
import { StarFeedback } from "@/components/icons/icons";
import { IconType } from "react-icons";
export function RegenerateDropdown({
options,
selected,
onSelect,
side,
maxHeight,
alternate,
}: {
alternate?: string;
options: StringOrNumberOption[];
selected: string | null;
onSelect: (value: string | number | null) => void;
includeDefault?: boolean;
side?: "top" | "right" | "bottom" | "left";
maxHeight?: string;
}) {
const [isOpen, setIsOpen] = useState(false);
const Dropdown = (
<div
className={`
border
border
rounded-lg
flex
flex-col
mx-2
bg-background
${maxHeight || "max-h-96"}
overflow-y-auto
overscroll-contain relative`}
>
<p
className="
sticky
top-0
flex
bg-background
font-bold
px-3
text-sm
py-1.5
"
>
Pick a model
</p>
{options.map((option, ind) => {
const isSelected = option.value === selected;
return (
<DefaultDropdownElement
key={option.value}
name={getDisplayNameForModel(option.name)}
description={option.description}
onSelect={() => onSelect(option.value)}
isSelected={isSelected}
/>
);
})}
</div>
);
return (
<Popover
open={isOpen}
onOpenChange={(open) => setIsOpen(open)}
content={
<div onClick={() => setIsOpen(!isOpen)}>
{!alternate ? (
<Hoverable size={16} icon={StarFeedback as IconType} />
) : (
<Hoverable
size={16}
icon={StarFeedback as IconType}
hoverText={getDisplayNameForModel(alternate)}
/>
)}
</div>
}
popover={Dropdown}
align="start"
side={side}
sideOffset={5}
triggerMaxWidth
/>
);
}
export default function RegenerateOption({
selectedAssistant,
regenerate,
overriddenModel,
onHoverChange,
}: {
selectedAssistant: Persona;
regenerate: (modelOverRide: LlmOverride) => Promise<void>;
overriddenModel?: string;
onHoverChange: (isHovered: boolean) => void;
}) {
const llmOverrideManager = useLlmOverride();
const { llmProviders } = useChatContext();
const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);
const llmOptionsByProvider: {
[provider: string]: { name: string; value: string }[];
} = {};
const uniqueModelNames = new Set<string>();
llmProviders.forEach((llmProvider) => {
if (!llmOptionsByProvider[llmProvider.provider]) {
llmOptionsByProvider[llmProvider.provider] = [];
}
(llmProvider.display_model_names || llmProvider.model_names).forEach(
(modelName) => {
if (!uniqueModelNames.has(modelName)) {
uniqueModelNames.add(modelName);
llmOptionsByProvider[llmProvider.provider].push({
name: modelName,
value: structureValue(
llmProvider.name,
llmProvider.provider,
modelName
),
});
}
}
);
});
const llmOptions = Object.entries(llmOptionsByProvider).flatMap(
([provider, options]) => [...options]
);
const currentModelName =
llmOverrideManager?.llmOverride.modelName ||
(selectedAssistant
? selectedAssistant.llm_model_version_override || llmName
: llmName);
return (
<div
className="group flex items-center relative"
onMouseEnter={() => onHoverChange(true)}
onMouseLeave={() => onHoverChange(false)}
>
<RegenerateDropdown
alternate={overriddenModel}
options={llmOptions}
selected={currentModelName}
onSelect={(value) => {
const { name, provider, modelName } = destructureValue(
value as string
);
regenerate({
name: name,
provider: provider,
modelName: modelName,
});
}}
/>
</div>
);
}

View File

@@ -88,6 +88,7 @@ export interface Message {
latestChildMessageId?: number | null;
alternateAssistantID?: number | null;
stackTrace?: string | null;
overridden_model?: string;
}
export interface BackendChatSession {
@@ -116,6 +117,7 @@ export interface BackendMessage {
files: FileDescriptor[];
tool_calls: ToolCallFinalResult[];
alternate_assistant_id?: number | null;
overridden_model?: string;
}
export interface MessageResponseIDInfo {

View File

@@ -114,6 +114,7 @@ export type PacketType =
| MessageResponseIDInfo;
export async function* sendMessage({
regenerate,
message,
fileDescriptors,
parentMessageId,
@@ -131,6 +132,7 @@ export async function* sendMessage({
alternateAssistantId,
signal,
}: {
regenerate: boolean;
message: string;
fileDescriptors: FileDescriptor[];
parentMessageId: number | null;
@@ -159,6 +161,7 @@ export async function* sendMessage({
prompt_id: promptId,
search_doc_ids: documentsAreSelected ? selectedDocumentIds : null,
file_descriptors: fileDescriptors,
regenerate,
retrieval_options: !documentsAreSelected
? {
run_search:
@@ -386,13 +389,12 @@ export function getLastSuccessfulMessageId(messageHistory: Message[]) {
.reverse()
.find(
(message) =>
message.type === "assistant" &&
(message.type === "assistant" || message.type === "system") &&
message.messageId !== -1 &&
message.messageId !== null
);
return lastSuccessfulMessage ? lastSuccessfulMessage?.messageId : null;
}
export function processRawChatHistory(
rawMessages: BackendMessage[]
): Map<number, Message> {
@@ -435,6 +437,7 @@ export function processRawChatHistory(
parentMessageId: messageInfo.parent_message,
childrenMessageIds: [],
latestChildMessageId: messageInfo.latest_child_message,
overridden_model: messageInfo.overridden_model,
};
messages.set(messageInfo.message_id, message);

View File

@@ -45,9 +45,12 @@ import { Persona } from "@/app/admin/assistants/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Citation } from "@/components/search/results/Citation";
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
import {
DislikeFeedbackIcon,
LikeFeedbackIcon,
ThumbsUpIcon,
ThumbsDownIcon,
LikeFeedback,
DislikeFeedback,
} from "@/components/icons/icons";
import {
CustomTooltip,
@@ -59,6 +62,8 @@ import { useMouseTracking } from "./hooks";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
import RegenerateOption from "../RegenerateOption";
import { LlmOverride } from "@/lib/hooks";
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
const TOOLS_WITH_CUSTOM_HANDLING = [
@@ -110,6 +115,8 @@ function FileDisplay({
}
export const AIMessage = ({
regenerate,
overriddenModel,
shared,
isActive,
toggleDocumentSelection,
@@ -132,9 +139,13 @@ export const AIMessage = ({
handleForceSearch,
retrievalDisabled,
currentPersona,
otherMessagesCanSwitchTo,
onMessageSelection,
}: {
shared?: boolean;
isActive?: boolean;
otherMessagesCanSwitchTo?: number[];
onMessageSelection?: (messageId: number) => void;
selectedDocuments?: DanswerDocument[] | null;
toggleDocumentSelection?: () => void;
docs?: DanswerDocument[] | null;
@@ -155,6 +166,8 @@ export const AIMessage = ({
handleSearchQueryEdit?: (query: string) => void;
handleForceSearch?: () => void;
retrievalDisabled?: boolean;
overriddenModel?: string;
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
}) => {
const toolCallGenerating = toolCall && !toolCall.tool_result;
const processContent = (content: string | JSX.Element) => {
@@ -183,6 +196,7 @@ export const AIMessage = ({
};
const finalContent = processContent(content as string);
const [isRegenerateHovered, setIsRegenerateHovered] = useState(false);
const { isHovering, trackedElementRef, hoverElementRef } = useMouseTracking();
const settings = useContext(SettingsContext);
@@ -240,10 +254,19 @@ export const AIMessage = ({
});
}
const currentMessageInd = messageId
? otherMessagesCanSwitchTo?.indexOf(messageId)
: undefined;
const uniqueSources: ValidSources[] = Array.from(
new Set((docs || []).map((doc) => doc.source_type))
).slice(0, 3);
const includeMessageSwitcher =
currentMessageInd !== undefined &&
onMessageSelection &&
otherMessagesCanSwitchTo &&
otherMessagesCanSwitchTo.length > 1;
return (
<div ref={trackedElementRef} className={"py-5 px-2 lg:px-5 relative flex "}>
<div
@@ -483,59 +506,124 @@ export const AIMessage = ({
(isActive ? (
<div
className={`
flex md:flex-row gap-x-0.5 mt-1
transition-transform duration-300 ease-in-out
transform opacity-100 translate-y-0"
flex md:flex-row gap-x-0.5 mt-1
transition-transform duration-300 ease-in-out
transform opacity-100 translate-y-0"
`}
>
<TooltipGroup>
<div className="flex justify-start w-full gap-x-0.5">
{includeMessageSwitcher && (
<div className="-mx-1 mr-auto">
<MessageSwitcher
currentPage={currentMessageInd + 1}
totalPages={otherMessagesCanSwitchTo.length}
handlePrevious={() => {
onMessageSelection(
otherMessagesCanSwitchTo[
currentMessageInd - 1
]
);
}}
handleNext={() => {
onMessageSelection(
otherMessagesCanSwitchTo[
currentMessageInd + 1
]
);
}}
/>
</div>
)}
</div>
<CustomTooltip showTick line content="Copy!">
<CopyButton content={content.toString()} />
</CustomTooltip>
<CustomTooltip showTick line content="Good response!">
<HoverableIcon
icon={<LikeFeedbackIcon />}
icon={<LikeFeedback />}
onClick={() => handleFeedback("like")}
/>
</CustomTooltip>
<CustomTooltip showTick line content="Bad response!">
<HoverableIcon
icon={<DislikeFeedbackIcon />}
icon={<DislikeFeedback size={16} />}
onClick={() => handleFeedback("dislike")}
/>
</CustomTooltip>
{regenerate && (
<RegenerateOption
onHoverChange={setIsRegenerateHovered}
selectedAssistant={currentPersona!}
regenerate={regenerate}
overriddenModel={overriddenModel}
/>
)}
</TooltipGroup>
</div>
) : (
<div
ref={hoverElementRef}
className={`
absolute -bottom-4
invisible ${(isHovering || settings?.isMobile) && "!visible"}
opacity-0 ${(isHovering || settings?.isMobile) && "!opacity-100"}
absolute -bottom-5
invisible ${(isHovering || isRegenerateHovered || settings?.isMobile) && "!visible"}
opacity-0 ${(isHovering || isRegenerateHovered || settings?.isMobile) && "!opacity-100"}
translate-y-2 ${(isHovering || settings?.isMobile) && "!translate-y-0"}
transition-transform duration-300 ease-in-out
flex md:flex-row gap-x-0.5 bg-background-125/40 p-1.5 rounded-lg
flex md:flex-row gap-x-0.5 bg-background-125/40 -mx-1.5 p-1.5 rounded-lg
`}
>
<TooltipGroup>
<div className="flex justify-start w-full gap-x-0.5">
{includeMessageSwitcher && (
<div className="-mx-1 mr-auto">
<MessageSwitcher
currentPage={currentMessageInd + 1}
totalPages={otherMessagesCanSwitchTo.length}
handlePrevious={() => {
onMessageSelection(
otherMessagesCanSwitchTo[
currentMessageInd - 1
]
);
}}
handleNext={() => {
onMessageSelection(
otherMessagesCanSwitchTo[
currentMessageInd + 1
]
);
}}
/>
</div>
)}
</div>
<CustomTooltip showTick line content="Copy!">
<CopyButton content={content.toString()} />
</CustomTooltip>
<CustomTooltip showTick line content="Good response!">
<HoverableIcon
icon={<LikeFeedbackIcon />}
icon={<LikeFeedback />}
onClick={() => handleFeedback("like")}
/>
</CustomTooltip>
<CustomTooltip showTick line content="Bad response!">
<HoverableIcon
icon={<DislikeFeedbackIcon />}
icon={<DislikeFeedback size={16} />}
onClick={() => handleFeedback("dislike")}
/>
</CustomTooltip>
{regenerate && (
<RegenerateOption
selectedAssistant={currentPersona!}
regenerate={regenerate}
overriddenModel={overriddenModel}
onHoverChange={setIsRegenerateHovered}
/>
)}
</TooltipGroup>
</div>
))}

View File

@@ -320,15 +320,15 @@ export const DefaultDropdown = forwardRef<HTMLDivElement, DefaultDropdownProps>(
const Content = (
<div
className={`
flex
text-sm
bg-background
px-3
py-1.5
rounded-lg
border
border-border
cursor-pointer`}
flex
text-sm
bg-background
px-3
py-1.5
rounded-lg
border
border-border
cursor-pointer`}
>
<p className="line-clamp-1">
{selectedOption?.name ||

View File

@@ -1,4 +1,3 @@
import { IconProps } from "@tremor/react";
import { IconType } from "react-icons";
const ICON_SIZE = 15;
@@ -7,13 +6,22 @@ export const Hoverable: React.FC<{
icon: IconType;
onClick?: () => void;
size?: number;
}> = ({ icon, onClick, size = ICON_SIZE }) => {
active?: boolean;
hoverText?: string;
}> = ({ icon: Icon, active, hoverText, onClick, size = ICON_SIZE }) => {
return (
<div
className="hover:bg-hover p-1.5 rounded h-fit cursor-pointer"
className={`group relative flex items-center overflow-hidden p-1.5 h-fit rounded-md cursor-pointer transition-all duration-300 ease-in-out hover:bg-hover`}
onClick={onClick}
>
{icon({ size: size, className: "my-auto" })}
<div className="flex items-center ">
<Icon size={size} className="text-gray-600 shrink-0" />
{hoverText && (
<div className="max-w-0 leading-none whitespace-nowrap overflow-hidden transition-all duration-300 ease-in-out group-hover:max-w-xs group-hover:ml-2">
<span className="text-xs text-gray-700">{hoverText}</span>
</div>
)}
</div>
</div>
);
};

View File

@@ -755,6 +755,85 @@ export const ChevronIcon = ({
);
};
export const StarFeedback = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
width="200"
height="200"
viewBox="0 0 24 24"
>
<path
fill="none"
stroke="currentColor"
strokeLinecap="round"
strokeLinejoin="round"
stroke-width="1.5"
d="m12.495 18.587l4.092 2.15a1.044 1.044 0 0 0 1.514-1.106l-.783-4.552a1.045 1.045 0 0 1 .303-.929l3.31-3.226a1.043 1.043 0 0 0-.575-1.785l-4.572-.657A1.044 1.044 0 0 1 15 7.907l-2.088-4.175a1.044 1.044 0 0 0-1.88 0L8.947 7.907a1.044 1.044 0 0 1-.783.575l-4.51.657a1.044 1.044 0 0 0-.584 1.785l3.309 3.226a1.044 1.044 0 0 1 .303.93l-.783 4.55a1.044 1.044 0 0 0 1.513 1.107l4.093-2.15a1.043 1.043 0 0 1 .991 0"
/>
</svg>
);
};
export const DislikeFeedback = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
width="200"
height="200"
viewBox="0 0 24 24"
>
<g
fill="none"
stroke="currentColor"
strokeLinecap="round"
strokeLinejoin="round"
stroke-width="1.5"
>
<path d="M5.75 2.75H4.568c-.98 0-1.775.795-1.775 1.776v8.284c0 .98.795 1.775 1.775 1.775h1.184c.98 0 1.775-.794 1.775-1.775V4.526c0-.98-.795-1.776-1.775-1.776" />
<path d="m21.16 11.757l-1.42-7.101a2.368 2.368 0 0 0-2.367-1.906h-7.48a2.367 2.367 0 0 0-2.367 2.367v7.101a3.231 3.231 0 0 0 1.184 2.367l.982 5.918a.887.887 0 0 0 1.278.65l1.1-.543a3.551 3.551 0 0 0 1.87-4.048l-.496-1.965h5.396a2.368 2.368 0 0 0 2.32-2.84" />
</g>
</svg>
);
};
export const LikeFeedback = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
width="200"
height="200"
viewBox="0 0 24 24"
>
<g
fill="none"
stroke="currentColor"
strokeLinecap="round"
strokeLinejoin="round"
stroke-width="1.5"
>
<path d="M5.75 9.415H4.568c-.98 0-1.775.794-1.775 1.775v8.284c0 .98.795 1.776 1.775 1.776h1.184c.98 0 1.775-.795 1.775-1.776V11.19c0-.98-.795-1.775-1.775-1.775" />
<path d="m21.16 12.243l-1.42 7.101a2.367 2.367 0 0 1-2.367 1.906h-7.48a2.367 2.367 0 0 1-2.367-2.367v-7.101A3.231 3.231 0 0 1 8.71 9.415l.982-5.918a.888.888 0 0 1 1.278-.65l1.1.544a3.55 3.55 0 0 1 1.87 4.047l-.496 1.965h5.396a2.367 2.367 0 0 1 2.32 2.84" />
</g>
</svg>
);
};
export const CheckmarkIcon = ({
size = 16,
className = defaultTailwindCSS,
@@ -2523,8 +2602,8 @@ export const SwapIcon = ({
<g
fill="none"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
strokeLinecap="round"
strokeLinejoin="round"
stroke-width="1.5"
>
<path d="M3.53 11.47v2.118a4.235 4.235 0 0 0 4.235 4.236H20.47M3.53 6.176h12.705a4.235 4.235 0 0 1 4.236 4.236v2.117" />
@@ -2550,8 +2629,8 @@ export const ClosedBookIcon = ({
<path
fill="none"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
strokeLinecap="round"
strokeLinejoin="round"
d="M12.5 13.54H3a1.5 1.5 0 0 1 0-3h8.5a1 1 0 0 0 1-1v-8a1 1 0 0 0-1-1H3A1.5 1.5 0 0 0 1.5 2v10m10-1.46v3"
/>
</svg>
@@ -2574,8 +2653,8 @@ export const PinIcon = ({
<path
fill="none"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
strokeLinecap="round"
strokeLinejoin="round"
stroke-width="1.5"
d="m17.942 6.076l2.442 2.442a1.22 1.22 0 0 1-.147 1.855l-1.757.232a1.697 1.697 0 0 0-.94.452c-.72.696-1.453 1.428-2.674 2.637c-.21.212-.358.478-.427.769l-.94 3.772a1.22 1.22 0 0 1-1.978.379l-3.04-3.052l-3.052-3.04a1.221 1.221 0 0 1 .379-1.978l3.747-.964a1.8 1.8 0 0 0 .77-.44c1.379-1.355 1.88-1.855 2.66-2.698c.233-.25.383-.565.428-.903l.232-1.783a1.221 1.221 0 0 1 1.856-.146zm-9.51 9.498L3.256 20.75"
/>
@@ -2599,8 +2678,8 @@ export const TwoRightArrowIcons = ({
<path
fill="none"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
strokeLinecap="round"
strokeLinejoin="round"
stroke-width="1.5"
d="m5.36 19l5.763-5.763a1.738 1.738 0 0 0 0-2.474L5.36 5m7 14l5.763-5.763a1.738 1.738 0 0 0 0-2.474L12.36 5"
/>