diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 9d0e3fea4..4ab259381 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -76,13 +76,7 @@ import { import { buildFilters } from "@/lib/search/utils"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import Dropzone from "react-dropzone"; -import { - checkLLMSupportsImageInput, - getFinalLLM, - destructureValue, - getLLMProviderOverrideForPersona, -} from "@/lib/llm/utils"; - +import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils"; import { ChatInputBar } from "./input/ChatInputBar"; import { useChatContext } from "@/components/context/ChatContext"; import { v4 as uuidv4 } from "uuid"; @@ -203,6 +197,12 @@ export function ChatPage({ const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open + const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null; + + const selectedChatSession = chatSessions.find( + (chatSession) => chatSession.id === existingChatSessionId + ); + useEffect(() => { if (user?.is_anonymous_user) { Cookies.set( @@ -240,12 +240,6 @@ export function ChatPage({ } }; - const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null; - - const selectedChatSession = chatSessions.find( - (chatSession) => chatSession.id === existingChatSessionId - ); - const chatSessionIdRef = useRef(existingChatSessionId); // Only updates on session load (ie. rename / switching chat session) @@ -293,12 +287,6 @@ export function ChatPage({ ); }; - const llmOverrideManager = useLlmOverride( - llmProviders, - user?.preferences.default_model, - selectedChatSession - ); - const [alternativeAssistant, setAlternativeAssistant] = useState(null); @@ -307,12 +295,27 @@ export function ChatPage({ const { recentAssistants, refreshRecentAssistants } = useAssistants(); - const liveAssistant: Persona | undefined = - alternativeAssistant || - selectedAssistant || - recentAssistants[0] || - finalAssistants[0] || - availableAssistants[0]; + const liveAssistant: Persona | undefined = useMemo( + () => + alternativeAssistant || + selectedAssistant || + recentAssistants[0] || + finalAssistants[0] || + availableAssistants[0], + [ + alternativeAssistant, + selectedAssistant, + recentAssistants, + finalAssistants, + availableAssistants, + ] + ); + + const llmOverrideManager = useLlmOverride( + llmProviders, + selectedChatSession, + liveAssistant + ); const noAssistants = liveAssistant == null || liveAssistant == undefined; @@ -320,24 +323,6 @@ export function ChatPage({ const uniqueSources = Array.from(new Set(availableSources)); const sources = uniqueSources.map((source) => getSourceMetadata(source)); - // always set the model override for the chat session, when an assistant, llm provider, or user preference exists - useEffect(() => { - if (noAssistants) return; - const personaDefault = getLLMProviderOverrideForPersona( - liveAssistant, - llmProviders - ); - - if (personaDefault) { - llmOverrideManager.updateLLMOverride(personaDefault); - } else if (user?.preferences.default_model) { - llmOverrideManager.updateLLMOverride( - destructureValue(user?.preferences.default_model) - ); - } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [liveAssistant, user?.preferences.default_model]); - const stopGenerating = () => { const currentSession = currentSessionId(); const controller = abortControllers.get(currentSession); @@ -419,7 +404,6 @@ export function ChatPage({ filterManager.setTimeRange(null); // reset LLM overrides (based on chat session!) - llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession); llmOverrideManager.updateTemperature(null); // remove uploaded files @@ -1283,13 +1267,11 @@ export function ChatPage({ modelProvider: modelOverRide?.name || llmOverrideManager.llmOverride.name || - llmOverrideManager.globalDefault.name || undefined, modelVersion: modelOverRide?.modelName || llmOverrideManager.llmOverride.modelName || searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) || - llmOverrideManager.globalDefault.modelName || undefined, temperature: llmOverrideManager.temperature || undefined, systemPromptOverride: @@ -1952,6 +1934,7 @@ export function ChatPage({ }; // eslint-disable-next-line react-hooks/exhaustive-deps }, [router]); + const [sharedChatSession, setSharedChatSession] = useState(); @@ -2059,7 +2042,9 @@ export function ChatPage({ {(settingsToggled || userSettingsToggled) && ( + llmOverrideManager.updateLLMOverride(newOverride) + } defaultModel={user?.preferences.default_model!} llmProviders={llmProviders} onClose={() => { @@ -2749,6 +2734,7 @@ export function ChatPage({ )} + void; llmProviders: LLMProviderDescriptor[]; - setLlmOverride?: Dispatch>; + setLlmOverride?: (newOverride: LlmOverride) => void; onClose: () => void; defaultModel: string | null; }) { diff --git a/web/src/lib/hooks.ts b/web/src/lib/hooks.ts index 0b67ab9d4..b4ee55747 100644 --- a/web/src/lib/hooks.ts +++ b/web/src/lib/hooks.ts @@ -13,16 +13,21 @@ import { errorHandlingFetcher } from "./fetcher"; import { useContext, useEffect, useState } from "react"; import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector"; import { Filters, SourceMetadata } from "./search/interfaces"; -import { destructureValue, structureValue } from "./llm/utils"; +import { + destructureValue, + findProviderForModel, + structureValue, +} from "./llm/utils"; import { ChatSession } from "@/app/chat/interfaces"; import { AllUsersResponse } from "./types"; import { Credential } from "./connectors/credentials"; import { SettingsContext } from "@/components/settings/SettingsProvider"; -import { PersonaLabel } from "@/app/admin/assistants/interfaces"; +import { Persona, PersonaLabel } from "@/app/admin/assistants/interfaces"; import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces"; import { isAnthropic } from "@/app/admin/configuration/llm/interfaces"; import { getSourceMetadata } from "./sources"; import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants"; +import { useUser } from "@/components/user/UserProvider"; const CREDENTIAL_URL = "/api/manage/admin/credential"; @@ -355,82 +360,141 @@ export interface LlmOverride { export interface LlmOverrideManager { llmOverride: LlmOverride; updateLLMOverride: (newOverride: LlmOverride) => void; - globalDefault: LlmOverride; - setGlobalDefault: React.Dispatch>; temperature: number | null; updateTemperature: (temperature: number | null) => void; updateModelOverrideForChatSession: (chatSession?: ChatSession) => void; imageFilesPresent: boolean; updateImageFilesPresent: (present: boolean) => void; + liveAssistant: Persona | null; } + +/* +LLM Override is as follows (i.e. this order) +- User override (explicitly set in the chat input bar) +- User preference (defaults to system wide default if no preference set) + +On switching to an existing or new chat session or a different assistant: +- If we have a live assistant after any switch with a model override, use that- otherwise use the above hierarchy + +Thus, the input should be +- User preference +- LLM Providers (which contain the system wide default) +- Current assistant + +Changes take place as +- liveAssistant or currentChatSession changes (and the associated model override is set) +- (uploadLLMOverride) User explicitly setting a model override (and we explicitly override and set the userSpecifiedOverride which we'll use in place of the user preferences unless overridden by an assistant) + +If we have a live assistant, we should use that model override +*/ + export function useLlmOverride( llmProviders: LLMProviderDescriptor[], - globalModel?: string | null, currentChatSession?: ChatSession, - defaultTemperature?: number + liveAssistant?: Persona ): LlmOverrideManager { + const { user } = useUser(); + + const [chatSession, setChatSession] = useState(null); + + const llmOverrideUpdate = () => { + if (!chatSession && currentChatSession) { + setChatSession(currentChatSession || null); + return; + } + + if (liveAssistant?.llm_model_version_override) { + setLlmOverride( + getValidLlmOverride(liveAssistant.llm_model_version_override) + ); + } else if (currentChatSession?.current_alternate_model) { + setLlmOverride( + getValidLlmOverride(currentChatSession.current_alternate_model) + ); + } else if (user?.preferences?.default_model) { + setLlmOverride(getValidLlmOverride(user.preferences.default_model)); + return; + } else { + const defaultProvider = llmProviders.find( + (provider) => provider.is_default_provider + ); + + if (defaultProvider) { + setLlmOverride({ + name: defaultProvider.name, + provider: defaultProvider.provider, + modelName: defaultProvider.default_model_name, + }); + } + } + setChatSession(currentChatSession || null); + }; + const getValidLlmOverride = ( overrideModel: string | null | undefined ): LlmOverride => { if (overrideModel) { const model = destructureValue(overrideModel); - const provider = llmProviders.find( - (p) => - p.model_names.includes(model.modelName) && - p.provider === model.provider + if (!(model.modelName && model.modelName.length > 0)) { + const provider = llmProviders.find((p) => + p.model_names.includes(overrideModel) + ); + if (provider) { + return { + modelName: overrideModel, + name: provider.name, + provider: provider.provider, + }; + } + } + + const provider = llmProviders.find((p) => + p.model_names.includes(model.modelName) ); + if (provider) { return { ...model, name: provider.name }; } } return { name: "", provider: "", modelName: "" }; }; + const [imageFilesPresent, setImageFilesPresent] = useState(false); const updateImageFilesPresent = (present: boolean) => { setImageFilesPresent(present); }; - const [globalDefault, setGlobalDefault] = useState( - getValidLlmOverride(globalModel) - ); - const updateLLMOverride = (newOverride: LlmOverride) => { - setLlmOverride( - getValidLlmOverride( - structureValue( - newOverride.name, - newOverride.provider, - newOverride.modelName - ) - ) - ); - }; + const [llmOverride, setLlmOverride] = useState({ + name: "", + provider: "", + modelName: "", + }); - const [llmOverride, setLlmOverride] = useState( - currentChatSession && currentChatSession.current_alternate_model - ? getValidLlmOverride(currentChatSession.current_alternate_model) - : { name: "", provider: "", modelName: "" } - ); + // Manually set the override + const updateLLMOverride = (newOverride: LlmOverride) => { + const provider = + newOverride.provider || + findProviderForModel(llmProviders, newOverride.modelName); + const structuredValue = structureValue( + newOverride.name, + provider, + newOverride.modelName + ); + setLlmOverride(getValidLlmOverride(structuredValue)); + }; const updateModelOverrideForChatSession = (chatSession?: ChatSession) => { - setLlmOverride( - chatSession && chatSession.current_alternate_model - ? getValidLlmOverride(chatSession.current_alternate_model) - : globalDefault - ); + if (chatSession && chatSession.current_alternate_model?.length > 0) { + setLlmOverride(getValidLlmOverride(chatSession.current_alternate_model)); + } }; - const [temperature, setTemperature] = useState( - defaultTemperature !== undefined ? defaultTemperature : 0 - ); + const [temperature, setTemperature] = useState(0); useEffect(() => { - setGlobalDefault(getValidLlmOverride(globalModel)); - }, [globalModel, llmProviders]); - - useEffect(() => { - setTemperature(defaultTemperature !== undefined ? defaultTemperature : 0); - }, [defaultTemperature]); + llmOverrideUpdate(); + }, [liveAssistant, currentChatSession]); useEffect(() => { if (isAnthropic(llmOverride.provider, llmOverride.modelName)) { @@ -450,12 +514,11 @@ export function useLlmOverride( updateModelOverrideForChatSession, llmOverride, updateLLMOverride, - globalDefault, - setGlobalDefault, temperature, updateTemperature, imageFilesPresent, updateImageFilesPresent, + liveAssistant: liveAssistant ?? null, }; } diff --git a/web/src/lib/llm/utils.ts b/web/src/lib/llm/utils.ts index 3eca6cacc..1880385e0 100644 --- a/web/src/lib/llm/utils.ts +++ b/web/src/lib/llm/utils.ts @@ -143,3 +143,11 @@ export const destructureValue = (value: string): LlmOverride => { modelName, }; }; + +export const findProviderForModel = ( + llmProviders: LLMProviderDescriptor[], + modelName: string +): string => { + const provider = llmProviders.find((p) => p.model_names.includes(modelName)); + return provider ? provider.provider : ""; +};