Fix LLM selection (#4078)

Chris Weaver 2025-02-21 11:32:57 -08:00 committed by GitHub
parent ba21bacbbf
commit e1ff9086a4
12 changed files with 172 additions and 142 deletions

View File

@@ -47,6 +47,7 @@ import {
   removeMessage,
   sendMessage,
   setMessageAsLatest,
+  updateLlmOverrideForChatSession,
   updateParentChildren,
   uploadFilesForChat,
   useScrollonStream,
@@ -65,7 +66,7 @@ import {
 import { usePopup } from "@/components/admin/connectors/Popup";
 import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams";
 import { useDocumentSelection } from "./useDocumentSelection";
-import { LlmOverride, useFilters, useLlmOverride } from "@/lib/hooks";
+import { LlmDescriptor, useFilters, useLlmManager } from "@/lib/hooks";
 import { ChatState, FeedbackType, RegenerationState } from "./types";
 import { DocumentResults } from "./documentSidebar/DocumentResults";
 import { OnyxInitializingLoader } from "@/components/OnyxInitializingLoader";
@@ -89,7 +90,11 @@ import {
 import { buildFilters } from "@/lib/search/utils";
 import { SettingsContext } from "@/components/settings/SettingsProvider";
 import Dropzone from "react-dropzone";
-import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
+import {
+  checkLLMSupportsImageInput,
+  getFinalLLM,
+  structureValue,
+} from "@/lib/llm/utils";
 import { ChatInputBar } from "./input/ChatInputBar";
 import { useChatContext } from "@/components/context/ChatContext";
 import { v4 as uuidv4 } from "uuid";
@@ -356,7 +361,7 @@ export function ChatPage({
     ]
   );

-  const llmOverrideManager = useLlmOverride(
+  const llmManager = useLlmManager(
     llmProviders,
     selectedChatSession,
     liveAssistant
@@ -1138,7 +1143,7 @@ export function ChatPage({
       forceSearch,
       isSeededChat,
       alternativeAssistantOverride = null,
-      modelOverRide,
+      modelOverride,
       regenerationRequest,
       overrideFileDescriptors,
     }: {
@@ -1148,7 +1153,7 @@ export function ChatPage({
       forceSearch?: boolean;
       isSeededChat?: boolean;
       alternativeAssistantOverride?: Persona | null;
-      modelOverRide?: LlmOverride;
+      modelOverride?: LlmDescriptor;
       regenerationRequest?: RegenerationRequest | null;
       overrideFileDescriptors?: FileDescriptor[];
     } = {}) => {
@@ -1191,6 +1196,22 @@ export function ChatPage({
       currChatSessionId = chatSessionIdRef.current as string;
     }
     frozenSessionId = currChatSessionId;
+    // update the selected model for the chat session if one is specified so that
+    // it persists across page reloads. Do not `await` here so that the message
+    // request can continue and this will just happen in the background.
+    // NOTE: only set the model override for the chat session once we send a
+    // message with it. If the user switches models and then starts a new
+    // chat session, it is unexpected for that model to be used when they
+    // return to this session the next day.
+    let finalLLM = modelOverride || llmManager.currentLlm;
+    updateLlmOverrideForChatSession(
+      currChatSessionId,
+      structureValue(
+        finalLLM.name || "",
+        finalLLM.provider || "",
+        finalLLM.modelName || ""
+      )
+    );

     updateStatesWithNewSessionId(currChatSessionId);
@@ -1250,11 +1271,14 @@ export function ChatPage({
           : null) ||
       (messageMap.size === 1 ? Array.from(messageMap.values())[0] : null);

-    const currentAssistantId = alternativeAssistantOverride
-      ? alternativeAssistantOverride.id
-      : alternativeAssistant
-        ? alternativeAssistant.id
-        : liveAssistant.id;
+    let currentAssistantId;
+    if (alternativeAssistantOverride) {
+      currentAssistantId = alternativeAssistantOverride.id;
+    } else if (alternativeAssistant) {
+      currentAssistantId = alternativeAssistant.id;
+    } else {
+      currentAssistantId = liveAssistant.id;
+    }

     resetInputBar();

     let messageUpdates: Message[] | null = null;
@@ -1326,15 +1350,13 @@ export function ChatPage({
           forceSearch,
           regenerate: regenerationRequest !== undefined,
           modelProvider:
-            modelOverRide?.name ||
-            llmOverrideManager.llmOverride.name ||
-            undefined,
+            modelOverride?.name || llmManager.currentLlm.name || undefined,
           modelVersion:
-            modelOverRide?.modelName ||
-            llmOverrideManager.llmOverride.modelName ||
+            modelOverride?.modelName ||
+            llmManager.currentLlm.modelName ||
             searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
             undefined,
-          temperature: llmOverrideManager.temperature || undefined,
+          temperature: llmManager.temperature || undefined,
           systemPromptOverride:
             searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
           useExistingUserMessage: isSeededChat,
@@ -1802,7 +1824,7 @@ export function ChatPage({
     const [_, llmModel] = getFinalLLM(
       llmProviders,
       liveAssistant,
-      llmOverrideManager.llmOverride
+      llmManager.currentLlm
     );
     const llmAcceptsImages = checkLLMSupportsImageInput(llmModel);
@@ -2121,7 +2143,7 @@ export function ChatPage({
   }, [searchParams, router]);

   useEffect(() => {
-    llmOverrideManager.updateImageFilesPresent(imageFileInMessageHistory);
+    llmManager.updateImageFilesPresent(imageFileInMessageHistory);
   }, [imageFileInMessageHistory]);

   const pathname = usePathname();
@@ -2175,9 +2197,9 @@ export function ChatPage({
   function createRegenerator(regenerationRequest: RegenerationRequest) {
     // Returns new function that only needs `modelOverRide` to be specified when called
-    return async function (modelOverRide: LlmOverride) {
+    return async function (modelOverride: LlmDescriptor) {
       return await onSubmit({
-        modelOverRide,
+        modelOverride,
         messageIdToResend: regenerationRequest.parentMessage.messageId,
         regenerationRequest,
         forceSearch: regenerationRequest.forceSearch,
@@ -2258,9 +2280,7 @@ export function ChatPage({
         {(settingsToggled || userSettingsToggled) && (
           <UserSettingsModal
             setPopup={setPopup}
-            setLlmOverride={(newOverride) =>
-              llmOverrideManager.updateLLMOverride(newOverride)
-            }
+            setCurrentLlm={(newLlm) => llmManager.updateCurrentLlm(newLlm)}
             defaultModel={user?.preferences.default_model!}
             llmProviders={llmProviders}
             onClose={() => {
@@ -2324,7 +2344,7 @@ export function ChatPage({
             <ShareChatSessionModal
               assistantId={liveAssistant?.id}
               message={message}
-              modelOverride={llmOverrideManager.llmOverride}
+              modelOverride={llmManager.currentLlm}
               chatSessionId={sharedChatSession.id}
               existingSharedStatus={sharedChatSession.shared_status}
               onClose={() => setSharedChatSession(null)}
@@ -2342,7 +2362,7 @@ export function ChatPage({
           <ShareChatSessionModal
             message={message}
             assistantId={liveAssistant?.id}
-            modelOverride={llmOverrideManager.llmOverride}
+            modelOverride={llmManager.currentLlm}
             chatSessionId={chatSessionIdRef.current}
             existingSharedStatus={chatSessionSharedStatus}
             onClose={() => setSharingModalVisible(false)}
@@ -3058,7 +3078,7 @@ export function ChatPage({
                     messageId: message.messageId,
                     parentMessage: parentMessage!,
                     forceSearch: true,
-                  })(llmOverrideManager.llmOverride);
+                  })(llmManager.currentLlm);
                 } else {
                   setPopup({
                     type: "error",
@@ -3203,7 +3223,7 @@ export function ChatPage({
                 availableDocumentSets={documentSets}
                 availableTags={tags}
                 filterManager={filterManager}
-                llmOverrideManager={llmOverrideManager}
+                llmManager={llmManager}
                 removeDocs={() => {
                   clearSelectedDocuments();
                 }}
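A note on the new block in the @@ -1191,6 +1196,22 hunk above: the chosen model is persisted to the chat session only at the moment a message is sent, and the call is deliberately not awaited so persistence never blocks the message request. A minimal sketch of that fire-and-forget pattern, assuming a placeholder endpoint (the real route lives behind updateLlmOverrideForChatSession and is not shown in this diff):

interface LlmDescriptor {
  name: string;
  provider: string;
  modelName: string;
}

// Same encoding as structureValue in @/lib/llm/utils: name__provider__modelName
const structureValue = (name: string, provider: string, modelName: string) =>
  `${name}__${provider}__${modelName}`;

function persistSessionModel(sessionId: string, llm: LlmDescriptor): void {
  // Intentionally not awaited by the caller; persistence is best-effort.
  fetch("/api/chat/update-session-model-placeholder", {
    // placeholder URL, not the actual backend route
    method: "PUT",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      chat_session_id: sessionId,
      new_alternate_model: structureValue(llm.name, llm.provider, llm.modelName),
    }),
  }).catch(() => {
    // A failed write should not interrupt the in-flight chat request.
  });
}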

View File

@@ -1,8 +1,8 @@
 import { useChatContext } from "@/components/context/ChatContext";
 import {
   getDisplayNameForModel,
-  LlmOverride,
-  useLlmOverride,
+  LlmDescriptor,
+  useLlmManager,
 } from "@/lib/hooks";
 import { StringOrNumberOption } from "@/components/Dropdown";
@@ -106,13 +106,13 @@ export default function RegenerateOption({
   onDropdownVisibleChange,
 }: {
   selectedAssistant: Persona;
-  regenerate: (modelOverRide: LlmOverride) => Promise<void>;
+  regenerate: (modelOverRide: LlmDescriptor) => Promise<void>;
   overriddenModel?: string;
   onHoverChange: (isHovered: boolean) => void;
   onDropdownVisibleChange: (isVisible: boolean) => void;
 }) {
   const { llmProviders } = useChatContext();
-  const llmOverrideManager = useLlmOverride(llmProviders);
+  const llmManager = useLlmManager(llmProviders);
   const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);
@@ -148,7 +148,7 @@ export default function RegenerateOption({
   );

   const currentModelName =
-    llmOverrideManager?.llmOverride.modelName ||
+    llmManager?.currentLlm.modelName ||
     (selectedAssistant
       ? selectedAssistant.llm_model_version_override || llmName
       : llmName);

View File

@@ -6,7 +6,7 @@ import { Persona } from "@/app/admin/assistants/interfaces";
 import LLMPopover from "./LLMPopover";
 import { InputPrompt } from "@/app/chat/interfaces";
-import { FilterManager, LlmOverrideManager } from "@/lib/hooks";
+import { FilterManager, LlmManager } from "@/lib/hooks";
 import { useChatContext } from "@/components/context/ChatContext";
 import { ChatFileType, FileDescriptor } from "../interfaces";
 import {
@@ -180,7 +180,7 @@ interface ChatInputBarProps {
   setMessage: (message: string) => void;
   stopGenerating: () => void;
   onSubmit: () => void;
-  llmOverrideManager: LlmOverrideManager;
+  llmManager: LlmManager;
   chatState: ChatState;
   alternativeAssistant: Persona | null;
   // assistants
@@ -225,7 +225,7 @@ export function ChatInputBar({
   availableSources,
   availableDocumentSets,
   availableTags,
-  llmOverrideManager,
+  llmManager,
   proSearchEnabled,
   setProSearchEnabled,
 }: ChatInputBarProps) {
@@ -781,7 +781,7 @@ export function ChatInputBar({
             <LLMPopover
               llmProviders={llmProviders}
-              llmOverrideManager={llmOverrideManager}
+              llmManager={llmManager}
               requiresImageGeneration={false}
               currentAssistant={selectedAssistant}
             />

View File

@@ -16,7 +16,7 @@ import {
   LLMProviderDescriptor,
 } from "@/app/admin/configuration/llm/interfaces";
 import { Persona } from "@/app/admin/assistants/interfaces";
-import { LlmOverrideManager } from "@/lib/hooks";
+import { LlmManager } from "@/lib/hooks";
 import {
   Tooltip,
@@ -31,21 +31,19 @@ import { useUser } from "@/components/user/UserProvider";

 interface LLMPopoverProps {
   llmProviders: LLMProviderDescriptor[];
-  llmOverrideManager: LlmOverrideManager;
+  llmManager: LlmManager;
   requiresImageGeneration?: boolean;
   currentAssistant?: Persona;
 }

 export default function LLMPopover({
   llmProviders,
-  llmOverrideManager,
+  llmManager,
   requiresImageGeneration,
   currentAssistant,
 }: LLMPopoverProps) {
   const [isOpen, setIsOpen] = useState(false);
   const { user } = useUser();
-  const { llmOverride, updateLLMOverride } = llmOverrideManager;
-  const currentLlm = llmOverride.modelName;

   const llmOptionsByProvider: {
     [provider: string]: {
@@ -93,19 +91,19 @@ export default function LLMPopover({
       : null;

   const [localTemperature, setLocalTemperature] = useState(
-    llmOverrideManager.temperature ?? 0.5
+    llmManager.temperature ?? 0.5
   );

   useEffect(() => {
-    setLocalTemperature(llmOverrideManager.temperature ?? 0.5);
-  }, [llmOverrideManager.temperature]);
+    setLocalTemperature(llmManager.temperature ?? 0.5);
+  }, [llmManager.temperature]);

   const handleTemperatureChange = (value: number[]) => {
     setLocalTemperature(value[0]);
   };

   const handleTemperatureChangeComplete = (value: number[]) => {
-    llmOverrideManager.updateTemperature(value[0]);
+    llmManager.updateTemperature(value[0]);
   };

   return (
@@ -120,15 +118,15 @@ export default function LLMPopover({
           toggle
           flexPriority="stiff"
           name={getDisplayNameForModel(
-            llmOverrideManager?.llmOverride.modelName ||
+            llmManager?.currentLlm.modelName ||
               defaultModelDisplayName ||
               "Models"
           )}
           Icon={getProviderIcon(
-            llmOverrideManager?.llmOverride.provider ||
+            llmManager?.currentLlm.provider ||
               defaultProvider?.provider ||
               "anthropic",
-            llmOverrideManager?.llmOverride.modelName ||
+            llmManager?.currentLlm.modelName ||
               defaultProvider?.default_model_name ||
               "claude-3-5-sonnet-20240620"
           )}
@@ -147,12 +145,12 @@ export default function LLMPopover({
                 <button
                   key={index}
                   className={`w-full flex items-center gap-x-2 px-3 py-2 text-sm text-left hover:bg-background-100 dark:hover:bg-neutral-800 transition-colors duration-150 ${
-                    currentLlm === name
+                    llmManager.currentLlm.modelName === name
                       ? "bg-background-100 dark:bg-neutral-900 text-text"
                       : "text-text-darker"
                   }`}
                   onClick={() => {
-                    updateLLMOverride(destructureValue(value));
+                    llmManager.updateCurrentLlm(destructureValue(value));
                     setIsOpen(false);
                   }}
                 >
@@ -172,7 +170,7 @@ export default function LLMPopover({
                     );
                   }
                 })()}
-                {llmOverrideManager.imageFilesPresent &&
+                {llmManager.imageFilesPresent &&
                   !checkLLMSupportsImageInput(name) && (
                     <TooltipProvider>
                       <Tooltip delayDuration={0}>
@@ -199,7 +197,7 @@ export default function LLMPopover({
           <div className="w-full px-3 py-2">
             <Slider
               value={[localTemperature]}
-              max={llmOverrideManager.maxTemperature}
+              max={llmManager.maxTemperature}
               min={0}
               step={0.01}
               onValueChange={handleTemperatureChange}

View File

@@ -65,7 +65,7 @@ export function getChatRetentionInfo(
   };
 }

-export async function updateModelOverrideForChatSession(
+export async function updateLlmOverrideForChatSession(
   chatSessionId: string,
   newAlternateModel: string
 ) {

View File

@@ -44,7 +44,7 @@ import { ValidSources } from "@/lib/types";
 import { useMouseTracking } from "./hooks";
 import { SettingsContext } from "@/components/settings/SettingsProvider";
 import RegenerateOption from "../RegenerateOption";
-import { LlmOverride } from "@/lib/hooks";
+import { LlmDescriptor } from "@/lib/hooks";
 import { ContinueGenerating } from "./ContinueMessage";
 import { MemoizedAnchor, MemoizedParagraph } from "./MemoizedTextComponents";
 import { extractCodeText, preprocessLaTeX } from "./codeUtils";
@@ -117,7 +117,7 @@ export const AgenticMessage = ({
   isComplete?: boolean;
   handleFeedback?: (feedbackType: FeedbackType) => void;
   overriddenModel?: string;
-  regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
+  regenerate?: (modelOverRide: LlmDescriptor) => Promise<void>;
   setPresentingDocument?: (document: OnyxDocument) => void;
   toggleDocDisplay?: (agentic: boolean) => void;
   error?: string | null;

View File

@@ -58,7 +58,7 @@ import { useMouseTracking } from "./hooks";
 import { SettingsContext } from "@/components/settings/SettingsProvider";
 import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
 import RegenerateOption from "../RegenerateOption";
-import { LlmOverride } from "@/lib/hooks";
+import { LlmDescriptor } from "@/lib/hooks";
 import { ContinueGenerating } from "./ContinueMessage";
 import { MemoizedAnchor, MemoizedParagraph } from "./MemoizedTextComponents";
 import { extractCodeText, preprocessLaTeX } from "./codeUtils";
@@ -213,7 +213,7 @@ export const AIMessage = ({
   handleForceSearch?: () => void;
   retrievalDisabled?: boolean;
   overriddenModel?: string;
-  regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
+  regenerate?: (modelOverRide: LlmDescriptor) => Promise<void>;
   setPresentingDocument: (document: OnyxDocument) => void;
   removePadding?: boolean;
 }) => {

View File

@@ -11,7 +11,7 @@ import { CopyButton } from "@/components/CopyButton";
 import { SEARCH_PARAM_NAMES } from "../searchParams";
 import { usePopup } from "@/components/admin/connectors/Popup";
 import { structureValue } from "@/lib/llm/utils";
-import { LlmOverride } from "@/lib/hooks";
+import { LlmDescriptor } from "@/lib/hooks";
 import { Separator } from "@/components/ui/separator";
 import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
@@ -38,7 +38,7 @@ async function generateShareLink(chatSessionId: string) {
 async function generateSeedLink(
   message?: string,
   assistantId?: number,
-  modelOverride?: LlmOverride
+  modelOverride?: LlmDescriptor
 ) {
   const baseUrl = `${window.location.protocol}//${window.location.host}`;
   const model = modelOverride
@@ -92,7 +92,7 @@ export function ShareChatSessionModal({
   onClose: () => void;
   message?: string;
   assistantId?: number;
-  modelOverride?: LlmOverride;
+  modelOverride?: LlmDescriptor;
 }) {
   const [shareLink, setShareLink] = useState<string>(
     existingSharedStatus === ChatSessionSharedStatus.Public

View File

@@ -1,6 +1,6 @@
 import { useContext, useEffect, useRef, useState } from "react";
 import { Modal } from "@/components/Modal";
-import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks";
+import { getDisplayNameForModel, LlmDescriptor } from "@/lib/hooks";
 import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
 import { destructureValue, structureValue } from "@/lib/llm/utils";
@@ -31,12 +31,12 @@ export function UserSettingsModal({
   setPopup,
   llmProviders,
   onClose,
-  setLlmOverride,
+  setCurrentLlm,
   defaultModel,
 }: {
   setPopup: (popupSpec: PopupSpec | null) => void;
   llmProviders: LLMProviderDescriptor[];
-  setLlmOverride?: (newOverride: LlmOverride) => void;
+  setCurrentLlm?: (newLlm: LlmDescriptor) => void;
   onClose: () => void;
   defaultModel: string | null;
 }) {
@@ -127,18 +127,14 @@ export function UserSettingsModal({
     );
   });

-  const llmOptions = Object.entries(llmOptionsByProvider).flatMap(
-    ([provider, options]) => [...options]
-  );
-
   const router = useRouter();

   const handleChangedefaultModel = async (defaultModel: string | null) => {
     try {
       const response = await setUserDefaultModel(defaultModel);

       if (response.ok) {
-        if (defaultModel && setLlmOverride) {
-          setLlmOverride(destructureValue(defaultModel));
+        if (defaultModel && setCurrentLlm) {
+          setCurrentLlm(destructureValue(defaultModel));
         }
         setPopup({
           message: "Default model updated successfully",

View File

@@ -360,18 +360,18 @@ export const useUsers = ({ includeApiKeys }: UseUsersParams) => {
   };
 };

-export interface LlmOverride {
+export interface LlmDescriptor {
   name: string;
   provider: string;
   modelName: string;
 }

-export interface LlmOverrideManager {
-  llmOverride: LlmOverride;
-  updateLLMOverride: (newOverride: LlmOverride) => void;
+export interface LlmManager {
+  currentLlm: LlmDescriptor;
+  updateCurrentLlm: (newOverride: LlmDescriptor) => void;
   temperature: number;
   updateTemperature: (temperature: number) => void;
-  updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
+  updateModelOverrideBasedOnChatSession: (chatSession?: ChatSession) => void;
   imageFilesPresent: boolean;
   updateImageFilesPresent: (present: boolean) => void;
   liveAssistant: Persona | null;
@@ -400,7 +400,7 @@ Thus, the input should be
 Changes take place as
 - liveAssistant or currentChatSession changes (and the associated model override is set)
-- (uploadLLMOverride) User explicitly setting a model override (and we explicitly override and set the userSpecifiedOverride which we'll use in place of the user preferences unless overridden by an assistant)
+- (updateCurrentLlm) User explicitly setting a model override (and we explicitly override and set the userSpecifiedOverride which we'll use in place of the user preferences unless overridden by an assistant)

 If we have a live assistant, we should use that model override
@@ -419,55 +419,78 @@ This approach ensures that user preferences are maintained for existing chats wh
 providing appropriate defaults for new conversations based on the available tools.
 */

-export function useLlmOverride(
+export function useLlmManager(
   llmProviders: LLMProviderDescriptor[],
   currentChatSession?: ChatSession,
   liveAssistant?: Persona
-): LlmOverrideManager {
+): LlmManager {
   const { user } = useUser();
+  const [userHasManuallyOverriddenLLM, setUserHasManuallyOverriddenLLM] =
+    useState(false);
   const [chatSession, setChatSession] = useState<ChatSession | null>(null);
+  const [currentLlm, setCurrentLlm] = useState<LlmDescriptor>({
+    name: "",
+    provider: "",
+    modelName: "",
+  });

-  const llmOverrideUpdate = () => {
-    if (liveAssistant?.llm_model_version_override) {
-      setLlmOverride(
-        getValidLlmOverride(liveAssistant.llm_model_version_override)
-      );
-    } else if (currentChatSession?.current_alternate_model) {
-      setLlmOverride(
-        getValidLlmOverride(currentChatSession.current_alternate_model)
-      );
-    } else if (user?.preferences?.default_model) {
-      setLlmOverride(getValidLlmOverride(user.preferences.default_model));
-      return;
-    } else {
-      const defaultProvider = llmProviders.find(
-        (provider) => provider.is_default_provider
-      );
-      if (defaultProvider) {
-        setLlmOverride({
-          name: defaultProvider.name,
-          provider: defaultProvider.provider,
-          modelName: defaultProvider.default_model_name,
-        });
-      }
-    }
+  const llmUpdate = () => {
+    /* Should be called when the live assistant or current chat session changes */
+
+    // separate function so we can `return` to break out
+    const _llmUpdate = () => {
+      // if the user has overridden in this session and just switched to a brand
+      // new session, use their manually specified model
+      if (userHasManuallyOverriddenLLM && !currentChatSession) {
+        return;
+      }
+
+      if (currentChatSession?.current_alternate_model) {
+        setCurrentLlm(
+          getValidLlmDescriptor(currentChatSession.current_alternate_model)
+        );
+      } else if (liveAssistant?.llm_model_version_override) {
+        setCurrentLlm(
+          getValidLlmDescriptor(liveAssistant.llm_model_version_override)
+        );
+      } else if (userHasManuallyOverriddenLLM) {
+        // if the user has an override and there's nothing special about the
+        // current chat session, use the override
+        return;
+      } else if (user?.preferences?.default_model) {
+        setCurrentLlm(getValidLlmDescriptor(user.preferences.default_model));
+      } else {
+        const defaultProvider = llmProviders.find(
+          (provider) => provider.is_default_provider
+        );
+        if (defaultProvider) {
+          setCurrentLlm({
+            name: defaultProvider.name,
+            provider: defaultProvider.provider,
+            modelName: defaultProvider.default_model_name,
+          });
+        }
+      }
+    };
+
+    _llmUpdate();
     setChatSession(currentChatSession || null);
   };

-  const getValidLlmOverride = (
-    overrideModel: string | null | undefined
-  ): LlmOverride => {
-    if (overrideModel) {
-      const model = destructureValue(overrideModel);
+  const getValidLlmDescriptor = (
+    modelName: string | null | undefined
+  ): LlmDescriptor => {
+    if (modelName) {
+      const model = destructureValue(modelName);
       if (!(model.modelName && model.modelName.length > 0)) {
         const provider = llmProviders.find((p) =>
-          p.model_names.includes(overrideModel)
+          p.model_names.includes(modelName)
         );
         if (provider) {
           return {
-            modelName: overrideModel,
+            modelName: modelName,
             name: provider.name,
             provider: provider.provider,
           };
@@ -491,38 +514,32 @@ export function useLlmOverride(
     setImageFilesPresent(present);
   };

-  const [llmOverride, setLlmOverride] = useState<LlmOverride>({
-    name: "",
-    provider: "",
-    modelName: "",
-  });
-
-  // Manually set the override
-  const updateLLMOverride = (newOverride: LlmOverride) => {
+  // Manually set the LLM
+  const updateCurrentLlm = (newLlm: LlmDescriptor) => {
     const provider =
-      newOverride.provider ||
-      findProviderForModel(llmProviders, newOverride.modelName);
+      newLlm.provider || findProviderForModel(llmProviders, newLlm.modelName);
     const structuredValue = structureValue(
-      newOverride.name,
+      newLlm.name,
       provider,
-      newOverride.modelName
+      newLlm.modelName
     );
-    setLlmOverride(getValidLlmOverride(structuredValue));
+    setCurrentLlm(getValidLlmDescriptor(structuredValue));
+    setUserHasManuallyOverriddenLLM(true);
   };

-  const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
+  const updateModelOverrideBasedOnChatSession = (chatSession?: ChatSession) => {
     if (chatSession && chatSession.current_alternate_model?.length > 0) {
-      setLlmOverride(getValidLlmOverride(chatSession.current_alternate_model));
+      setCurrentLlm(getValidLlmDescriptor(chatSession.current_alternate_model));
     }
   };

   const [temperature, setTemperature] = useState<number>(() => {
-    llmOverrideUpdate();
+    llmUpdate();
     if (currentChatSession?.current_temperature_override != null) {
       return Math.min(
         currentChatSession.current_temperature_override,
-        isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0
+        isAnthropic(currentLlm.provider, currentLlm.modelName) ? 1.0 : 2.0
       );
     } else if (
       liveAssistant?.tools.some((tool) => tool.name === SEARCH_TOOL_ID)
@@ -533,22 +550,23 @@ export function useLlmOverride(
   });

   const maxTemperature = useMemo(() => {
-    return isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0;
-  }, [llmOverride]);
+    return isAnthropic(currentLlm.provider, currentLlm.modelName) ? 1.0 : 2.0;
+  }, [currentLlm]);

   useEffect(() => {
-    if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
+    if (isAnthropic(currentLlm.provider, currentLlm.modelName)) {
       const newTemperature = Math.min(temperature, 1.0);
       setTemperature(newTemperature);
       if (chatSession?.id) {
         updateTemperatureOverrideForChatSession(chatSession.id, newTemperature);
       }
     }
-  }, [llmOverride]);
+  }, [currentLlm]);

   useEffect(() => {
+    llmUpdate();
+
     if (!chatSession && currentChatSession) {
-      setChatSession(currentChatSession || null);
       if (temperature) {
         updateTemperatureOverrideForChatSession(
           currentChatSession.id,
@@ -570,7 +588,7 @@ export function useLlmOverride(
   }, [liveAssistant, currentChatSession]);

   const updateTemperature = (temperature: number) => {
-    if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
+    if (isAnthropic(currentLlm.provider, currentLlm.modelName)) {
       setTemperature((prevTemp) => Math.min(temperature, 1.0));
     } else {
       setTemperature(temperature);
@@ -581,9 +599,9 @@ export function useLlmOverride(
   };

   return {
-    updateModelOverrideForChatSession,
-    llmOverride,
-    updateLLMOverride,
+    updateModelOverrideBasedOnChatSession,
+    currentLlm,
+    updateCurrentLlm,
     temperature,
     updateTemperature,
     imageFilesPresent,
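Taken together, the rewritten hook resolves the active model in a fixed precedence order. A standalone sketch of that order (illustrative names; parse stands in for getValidLlmDescriptor, and returning null means "keep the current selection"):

interface LlmDescriptor {
  name: string;
  provider: string;
  modelName: string;
}

function resolveLlm(opts: {
  userHasManuallyOverriddenLLM: boolean;
  hasChatSession: boolean;
  sessionAlternateModel?: string; // chatSession.current_alternate_model
  assistantModelOverride?: string; // liveAssistant.llm_model_version_override
  userDefaultModel?: string; // user.preferences.default_model
  providerDefault: LlmDescriptor | null; // from the default provider
  parse: (value: string) => LlmDescriptor;
}): LlmDescriptor | null {
  // A manual in-session choice carries over to a brand-new session.
  if (opts.userHasManuallyOverriddenLLM && !opts.hasChatSession) return null;
  // 1. A model persisted on the session wins.
  if (opts.sessionAlternateModel) return opts.parse(opts.sessionAlternateModel);
  // 2. Then the assistant's own model override.
  if (opts.assistantModelOverride) return opts.parse(opts.assistantModelOverride);
  // 3. A manual choice survives when the session has nothing special.
  if (opts.userHasManuallyOverriddenLLM) return null;
  // 4. Then the user's default model, 5. then the default provider.
  if (opts.userDefaultModel) return opts.parse(opts.userDefaultModel);
  return opts.providerDefault;
}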

View File

@@ -1,11 +1,11 @@
 import { Persona } from "@/app/admin/assistants/interfaces";
 import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
-import { LlmOverride } from "@/lib/hooks";
+import { LlmDescriptor } from "@/lib/hooks";

 export function getFinalLLM(
   llmProviders: LLMProviderDescriptor[],
   persona: Persona | null,
-  llmOverride: LlmOverride | null
+  currentLlm: LlmDescriptor | null
 ): [string, string] {
   const defaultProvider = llmProviders.find(
     (llmProvider) => llmProvider.is_default_provider
@@ -26,9 +26,9 @@ export function getFinalLLM(
     model = persona.llm_model_version_override || model;
   }

-  if (llmOverride) {
-    provider = llmOverride.provider || provider;
-    model = llmOverride.modelName || model;
+  if (currentLlm) {
+    provider = currentLlm.provider || provider;
+    model = currentLlm.modelName || model;
   }

   return [provider, model];
@@ -37,7 +37,7 @@ export function getFinalLLM(
 export function getLLMProviderOverrideForPersona(
   liveAssistant: Persona,
   llmProviders: LLMProviderDescriptor[]
-): LlmOverride | null {
+): LlmDescriptor | null {
   const overrideProvider = liveAssistant.llm_model_provider_override;
   const overrideModel = liveAssistant.llm_model_version_override;
@@ -135,7 +135,7 @@ export const structureValue = (
   return `${name}__${provider}__${modelName}`;
 };

-export const destructureValue = (value: string): LlmOverride => {
+export const destructureValue = (value: string): LlmDescriptor => {
   const [displayName, provider, modelName] = value.split("__");
   return {
     name: displayName,
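For reference, structureValue and destructureValue above form a simple round trip over the name__provider__modelName encoding (this assumes no individual field itself contains "__"; the values below are illustrative):

const value = structureValue("OpenAI", "openai", "gpt-4o");
// value === "OpenAI__openai__gpt-4o"

const llm = destructureValue(value);
// llm deep-equals { name: "OpenAI", provider: "openai", modelName: "gpt-4o" }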

View File

@@ -1,5 +1,3 @@
-import { LlmOverride } from "../hooks";
-
 export async function setUserDefaultModel(
   model: string | null
 ): Promise<Response> {