Temperature (#3310)

* fix temperatures for default llm

* ensure anthropic models don't overflow

* minor cleanup

* fix typing
pablodanswer committed 2024-12-03 09:22:22 -08:00 (committed by GitHub)
parent 6c2269e565 · commit cd5f2293ad
8 changed files with 50 additions and 59 deletions
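The "overflow" in the commit message is Anthropic's sampling range: Anthropic chat models accept temperature values in [0, 1], while the slider in this UI goes up to 2. Below is a minimal sketch of the clamping rule the frontend hook introduces in this commit; clampTemperature is an illustrative standalone name, not an identifier from the diff.

// Hypothetical standalone version of the rule that useLlmOverride
// applies via setTemperature in the hooks diff at the bottom of this commit.
const clampTemperature = (
  provider: string,
  modelName: string,
  temperature: number | null
): number | null =>
  provider === "anthropic" || modelName.toLowerCase().includes("claude")
    ? Math.min(temperature ?? 0, 1.0) // Anthropic: cap at 1.0
    : temperature;                    // other providers: pass through

clampTemperature("anthropic", "claude-3-opus", 1.7); // => 1.0 (made-up model)
clampTemperature("openai", "gpt-4o", 1.7);           // => 1.7 (made-up model)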

View File

@@ -71,6 +71,7 @@ def get_llms_for_persona(
         api_base=llm_provider.api_base,
         api_version=llm_provider.api_version,
         custom_config=llm_provider.custom_config,
+        temperature=temperature_override,
         additional_headers=additional_headers,
         long_term_logger=long_term_logger,
     )
@@ -128,11 +129,13 @@ def get_llm(
     api_base: str | None = None,
     api_version: str | None = None,
     custom_config: dict[str, str] | None = None,
-    temperature: float = GEN_AI_TEMPERATURE,
+    temperature: float | None = None,
     timeout: int = QA_TIMEOUT,
    additional_headers: dict[str, str] | None = None,
     long_term_logger: LongTermLogger | None = None,
 ) -> LLM:
+    if temperature is None:
+        temperature = GEN_AI_TEMPERATURE
     return DefaultMultiLLM(
         model_provider=provider,
         model_name=model,
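Defaulting temperature to None as a sentinel, rather than to GEN_AI_TEMPERATURE directly, lets callers such as get_llms_for_persona forward an optional temperature_override while still falling back to the global default when no override is set.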

View File

@@ -89,3 +89,6 @@ export const getProviderIcon = (providerName: string, modelName?: string) => {
       return CPUIcon;
   }
 };
+
+export const isAnthropic = (provider: string, modelName: string) =>
+  provider === "anthropic" || modelName.toLowerCase().includes("claude");
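For reference, a quick sketch of what this helper matches; the inputs are made-up examples, not values from this commit:

isAnthropic("anthropic", "claude-3-5-sonnet"); // true  (provider match)
isAnthropic("bedrock", "anthropic.claude-v2"); // true  ("claude" in model name)
isAnthropic("openai", "gpt-4o");               // false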

View File

@@ -411,7 +411,7 @@ export function ChatPage({
       // reset LLM overrides (based on chat session!)
       llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession);
-      llmOverrideManager.setTemperature(null);
+      llmOverrideManager.updateTemperature(null);
       // remove uploaded files
       setCurrentMessageFiles([]);

View File

@@ -14,7 +14,6 @@ import { destructureValue, getFinalLLM, structureValue } from "@/lib/llm/utils";
 import { useState } from "react";
 import { Hoverable } from "@/components/Hoverable";
 import { Popover } from "@/components/popover/Popover";
-import { StarFeedback } from "@/components/icons/icons";
 import { IconType } from "react-icons";
 import { FiRefreshCw } from "react-icons/fi";

View File

@@ -35,25 +35,9 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
       checkPersonaRequiresImageGeneration(currentAssistant);
     const { llmProviders } = useChatContext();
-    const { setLlmOverride, temperature, setTemperature } = llmOverrideManager;
+    const { setLlmOverride, temperature, updateTemperature } =
+      llmOverrideManager;
     const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
-    const [localTemperature, setLocalTemperature] = useState<number>(
-      temperature || 0
-    );
-
-    const debouncedSetTemperature = useCallback(
-      (value: number) => {
-        const debouncedFunction = debounce((value: number) => {
-          setTemperature(value);
-        }, 300);
-        return debouncedFunction(value);
-      },
-      [setTemperature]
-    );
-
-    const handleTemperatureChange = (value: number) => {
-      setLocalTemperature(value);
-      debouncedSetTemperature(value);
-    };

     return (
       <div className="w-full">
@@ -108,26 +92,26 @@
                 <input
                   type="range"
                   onChange={(e) =>
-                    handleTemperatureChange(parseFloat(e.target.value))
+                    updateTemperature(parseFloat(e.target.value))
                   }
                   className="w-full p-2 border border-border rounded-md"
                   min="0"
                   max="2"
                   step="0.01"
-                  value={localTemperature}
+                  value={temperature || 0}
                 />
                 <div
                   className="absolute text-sm"
                   style={{
-                    left: `${(localTemperature || 0) * 50}%`,
+                    left: `${(temperature || 0) * 50}%`,
                     transform: `translateX(-${Math.min(
-                      Math.max((localTemperature || 0) * 50, 10),
+                      Math.max((temperature || 0) * 50, 10),
                       90
                     )}%)`,
                     top: "-1.5rem",
                   }}
                 >
-                  {localTemperature}
+                  {temperature}
                 </div>
               </div>
             </>

View File

@@ -46,7 +46,7 @@ const AssistantSelector = ({
   liveAssistant: Persona;
   onAssistantChange: (assistant: Persona) => void;
   chatSessionId?: string;
-  llmOverrideManager?: LlmOverrideManager;
+  llmOverrideManager: LlmOverrideManager;
   isMobile: boolean;
 }) => {
   const { finalAssistants } = useAssistants();
@@ -54,11 +54,9 @@
   const dropdownRef = useRef<HTMLDivElement>(null);
   const { llmProviders } = useChatContext();
   const { user } = useUser();

   const [assistants, setAssistants] = useState<Persona[]>(finalAssistants);
   const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
-  const [localTemperature, setLocalTemperature] = useState<number>(
-    llmOverrideManager?.temperature || 0
-  );

   // Initialize selectedTab from localStorage
   const [selectedTab, setSelectedTab] = useState<number>(() => {
@@ -92,21 +90,6 @@
     }
   };

-  const debouncedSetTemperature = useCallback(
-    (value: number) => {
-      const debouncedFunction = debounce((value: number) => {
-        llmOverrideManager?.setTemperature(value);
-      }, 300);
-      return debouncedFunction(value);
-    },
-    [llmOverrideManager]
-  );
-
-  const handleTemperatureChange = (value: number) => {
-    setLocalTemperature(value);
-    debouncedSetTemperature(value);
-  };
-
   // Handle tab change and update localStorage
   const handleTabChange = (index: number) => {
     setSelectedTab(index);
@@ -119,7 +102,7 @@
   const [_, currentLlm] = getFinalLLM(
     llmProviders,
     liveAssistant,
-    llmOverrideManager?.llmOverride ?? null
+    llmOverrideManager.llmOverride ?? null
   );

   const requiresImageGeneration =
@@ -204,11 +187,10 @@
             llmProviders={llmProviders}
             currentLlm={currentLlm}
             userDefault={userDefaultModel}
-            includeUserDefault={true}
             onSelect={(value: string | null) => {
               if (value == null) return;
               const { modelName, name, provider } = destructureValue(value);
-              llmOverrideManager?.setLlmOverride({
+              llmOverrideManager.setLlmOverride({
                 name,
                 provider,
                 modelName,
@@ -216,7 +198,6 @@
               if (chatSessionId) {
                 updateModelOverrideForChatSession(chatSessionId, value);
               }
-              setIsOpen(false);
             }}
           />
           <div className="mt-4">
@@ -243,26 +224,31 @@
                 <input
                   type="range"
                   onChange={(e) =>
-                    handleTemperatureChange(parseFloat(e.target.value))
+                    llmOverrideManager.updateTemperature(
+                      parseFloat(e.target.value)
+                    )
                   }
                   className="w-full p-2 border border-border rounded-md"
                   min="0"
                   max="2"
                   step="0.01"
-                  value={localTemperature}
+                  value={llmOverrideManager.temperature?.toString() || "0"}
                 />
                 <div
                   className="absolute text-sm"
                   style={{
-                    left: `${(localTemperature || 0) * 50}%`,
+                    left: `${(llmOverrideManager.temperature || 0) * 50}%`,
                     transform: `translateX(-${Math.min(
-                      Math.max((localTemperature || 0) * 50, 10),
+                      Math.max(
+                        (llmOverrideManager.temperature || 0) * 50,
+                        10
+                      ),
                       90
                     )}%)`,
                     top: "-1.5rem",
                   }}
                 >
-                  {localTemperature}
+                  {llmOverrideManager.temperature}
                 </div>
               </div>
             </>

View File

@@ -19,7 +19,6 @@ interface LlmListProps {
   scrollable?: boolean;
   hideProviderIcon?: boolean;
   requiresImageGeneration?: boolean;
-  includeUserDefault?: boolean;
   currentAssistant?: Persona;
 }
@@ -31,7 +30,6 @@ export const LlmList: React.FC<LlmListProps> = ({
   userDefault,
   scrollable,
   requiresImageGeneration,
-  includeUserDefault = false,
 }) => {
   const llmOptionsByProvider: {
     [provider: string]: {

View File

@@ -16,6 +16,7 @@ import { UsersResponse } from "./users/interfaces";
 import { Credential } from "./connectors/credentials";
 import { SettingsContext } from "@/components/settings/SettingsProvider";
 import { PersonaCategory } from "@/app/admin/assistants/interfaces";
+import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";

 const CREDENTIAL_URL = "/api/manage/admin/credential";
@@ -71,7 +72,9 @@
   getEditable = false
 ) => {
   const { mutate } = useSWRConfig();
-  const url = `${INDEXING_STATUS_URL}${getEditable ? "?get_editable=true" : ""}`;
+  const url = `${INDEXING_STATUS_URL}${
+    getEditable ? "?get_editable=true" : ""
+  }`;
   const swrResponse = useSWR<ConnectorIndexingStatus<any, any>[]>(
     url,
     errorHandlingFetcher,
@@ -157,7 +160,7 @@ export interface LlmOverrideManager {
   globalDefault: LlmOverride;
   setGlobalDefault: React.Dispatch<React.SetStateAction<LlmOverride>>;
   temperature: number | null;
-  setTemperature: React.Dispatch<React.SetStateAction<number | null>>;
+  updateTemperature: (temperature: number | null) => void;
   updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
 }
@@ -212,6 +215,20 @@ export function useLlmOverride(
     setTemperature(defaultTemperature !== undefined ? defaultTemperature : 0);
   }, [defaultTemperature]);

+  useEffect(() => {
+    if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
+      setTemperature((prevTemp) => Math.min(prevTemp ?? 0, 1.0));
+    }
+  }, [llmOverride]);
+
+  const updateTemperature = (temperature: number | null) => {
+    if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
+      setTemperature(Math.min(temperature ?? 0, 1.0));
+    } else {
+      setTemperature(temperature);
+    }
+  };
+
   return {
     updateModelOverrideForChatSession,
     llmOverride,
@@ -219,9 +236,10 @@
     globalDefault,
     setGlobalDefault,
     temperature,
-    setTemperature,
+    updateTemperature,
   };
 }

 /*
 EE Only APIs
 */
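A usage sketch of the new hook surface, with illustrative values: consumers call updateTemperature instead of a raw state setter, and the useEffect above re-clamps whenever the model override changes.

// Inside a component consuming useLlmOverride() -- values are illustrative:
const { setLlmOverride, temperature, updateTemperature } = llmOverrideManager;

updateTemperature(1.8); // non-Anthropic override active: stored as 1.8

setLlmOverride({
  name: "Claude 3.5 Sonnet", // hypothetical provider entry
  provider: "anthropic",
  modelName: "claude-3-5-sonnet",
});
// the useEffect re-runs and clamps: temperature === 1.0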