mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-01 02:00:48 +02:00
Temperature (#3310)
* fix temperatures for default llm * ensure anthropic models don't overflow * minor cleanup * k * k * k * fix typing
This commit is contained in:
@ -71,6 +71,7 @@ def get_llms_for_persona(
|
|||||||
api_base=llm_provider.api_base,
|
api_base=llm_provider.api_base,
|
||||||
api_version=llm_provider.api_version,
|
api_version=llm_provider.api_version,
|
||||||
custom_config=llm_provider.custom_config,
|
custom_config=llm_provider.custom_config,
|
||||||
|
temperature=temperature_override,
|
||||||
additional_headers=additional_headers,
|
additional_headers=additional_headers,
|
||||||
long_term_logger=long_term_logger,
|
long_term_logger=long_term_logger,
|
||||||
)
|
)
|
||||||
@ -128,11 +129,13 @@ def get_llm(
|
|||||||
api_base: str | None = None,
|
api_base: str | None = None,
|
||||||
api_version: str | None = None,
|
api_version: str | None = None,
|
||||||
custom_config: dict[str, str] | None = None,
|
custom_config: dict[str, str] | None = None,
|
||||||
temperature: float = GEN_AI_TEMPERATURE,
|
temperature: float | None = None,
|
||||||
timeout: int = QA_TIMEOUT,
|
timeout: int = QA_TIMEOUT,
|
||||||
additional_headers: dict[str, str] | None = None,
|
additional_headers: dict[str, str] | None = None,
|
||||||
long_term_logger: LongTermLogger | None = None,
|
long_term_logger: LongTermLogger | None = None,
|
||||||
) -> LLM:
|
) -> LLM:
|
||||||
|
if temperature is None:
|
||||||
|
temperature = GEN_AI_TEMPERATURE
|
||||||
return DefaultMultiLLM(
|
return DefaultMultiLLM(
|
||||||
model_provider=provider,
|
model_provider=provider,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
|
@ -89,3 +89,6 @@ export const getProviderIcon = (providerName: string, modelName?: string) => {
|
|||||||
return CPUIcon;
|
return CPUIcon;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const isAnthropic = (provider: string, modelName: string) =>
|
||||||
|
provider === "anthropic" || modelName.toLowerCase().includes("claude");
|
||||||
|
@ -411,7 +411,7 @@ export function ChatPage({
|
|||||||
|
|
||||||
// reset LLM overrides (based on chat session!)
|
// reset LLM overrides (based on chat session!)
|
||||||
llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession);
|
llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession);
|
||||||
llmOverrideManager.setTemperature(null);
|
llmOverrideManager.updateTemperature(null);
|
||||||
|
|
||||||
// remove uploaded files
|
// remove uploaded files
|
||||||
setCurrentMessageFiles([]);
|
setCurrentMessageFiles([]);
|
||||||
|
@ -14,7 +14,6 @@ import { destructureValue, getFinalLLM, structureValue } from "@/lib/llm/utils";
|
|||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
import { Hoverable } from "@/components/Hoverable";
|
import { Hoverable } from "@/components/Hoverable";
|
||||||
import { Popover } from "@/components/popover/Popover";
|
import { Popover } from "@/components/popover/Popover";
|
||||||
import { StarFeedback } from "@/components/icons/icons";
|
|
||||||
import { IconType } from "react-icons";
|
import { IconType } from "react-icons";
|
||||||
import { FiRefreshCw } from "react-icons/fi";
|
import { FiRefreshCw } from "react-icons/fi";
|
||||||
|
|
||||||
|
@ -35,25 +35,9 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
|
|||||||
checkPersonaRequiresImageGeneration(currentAssistant);
|
checkPersonaRequiresImageGeneration(currentAssistant);
|
||||||
|
|
||||||
const { llmProviders } = useChatContext();
|
const { llmProviders } = useChatContext();
|
||||||
const { setLlmOverride, temperature, setTemperature } = llmOverrideManager;
|
const { setLlmOverride, temperature, updateTemperature } =
|
||||||
|
llmOverrideManager;
|
||||||
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
|
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
|
||||||
const [localTemperature, setLocalTemperature] = useState<number>(
|
|
||||||
temperature || 0
|
|
||||||
);
|
|
||||||
const debouncedSetTemperature = useCallback(
|
|
||||||
(value: number) => {
|
|
||||||
const debouncedFunction = debounce((value: number) => {
|
|
||||||
setTemperature(value);
|
|
||||||
}, 300);
|
|
||||||
return debouncedFunction(value);
|
|
||||||
},
|
|
||||||
[setTemperature]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleTemperatureChange = (value: number) => {
|
|
||||||
setLocalTemperature(value);
|
|
||||||
debouncedSetTemperature(value);
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="w-full">
|
<div className="w-full">
|
||||||
@ -108,26 +92,26 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
|
|||||||
<input
|
<input
|
||||||
type="range"
|
type="range"
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
handleTemperatureChange(parseFloat(e.target.value))
|
updateTemperature(parseFloat(e.target.value))
|
||||||
}
|
}
|
||||||
className="w-full p-2 border border-border rounded-md"
|
className="w-full p-2 border border-border rounded-md"
|
||||||
min="0"
|
min="0"
|
||||||
max="2"
|
max="2"
|
||||||
step="0.01"
|
step="0.01"
|
||||||
value={localTemperature}
|
value={temperature || 0}
|
||||||
/>
|
/>
|
||||||
<div
|
<div
|
||||||
className="absolute text-sm"
|
className="absolute text-sm"
|
||||||
style={{
|
style={{
|
||||||
left: `${(localTemperature || 0) * 50}%`,
|
left: `${(temperature || 0) * 50}%`,
|
||||||
transform: `translateX(-${Math.min(
|
transform: `translateX(-${Math.min(
|
||||||
Math.max((localTemperature || 0) * 50, 10),
|
Math.max((temperature || 0) * 50, 10),
|
||||||
90
|
90
|
||||||
)}%)`,
|
)}%)`,
|
||||||
top: "-1.5rem",
|
top: "-1.5rem",
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{localTemperature}
|
{temperature}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
|
@ -46,7 +46,7 @@ const AssistantSelector = ({
|
|||||||
liveAssistant: Persona;
|
liveAssistant: Persona;
|
||||||
onAssistantChange: (assistant: Persona) => void;
|
onAssistantChange: (assistant: Persona) => void;
|
||||||
chatSessionId?: string;
|
chatSessionId?: string;
|
||||||
llmOverrideManager?: LlmOverrideManager;
|
llmOverrideManager: LlmOverrideManager;
|
||||||
isMobile: boolean;
|
isMobile: boolean;
|
||||||
}) => {
|
}) => {
|
||||||
const { finalAssistants } = useAssistants();
|
const { finalAssistants } = useAssistants();
|
||||||
@ -54,11 +54,9 @@ const AssistantSelector = ({
|
|||||||
const dropdownRef = useRef<HTMLDivElement>(null);
|
const dropdownRef = useRef<HTMLDivElement>(null);
|
||||||
const { llmProviders } = useChatContext();
|
const { llmProviders } = useChatContext();
|
||||||
const { user } = useUser();
|
const { user } = useUser();
|
||||||
|
|
||||||
const [assistants, setAssistants] = useState<Persona[]>(finalAssistants);
|
const [assistants, setAssistants] = useState<Persona[]>(finalAssistants);
|
||||||
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
|
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
|
||||||
const [localTemperature, setLocalTemperature] = useState<number>(
|
|
||||||
llmOverrideManager?.temperature || 0
|
|
||||||
);
|
|
||||||
|
|
||||||
// Initialize selectedTab from localStorage
|
// Initialize selectedTab from localStorage
|
||||||
const [selectedTab, setSelectedTab] = useState<number>(() => {
|
const [selectedTab, setSelectedTab] = useState<number>(() => {
|
||||||
@ -92,21 +90,6 @@ const AssistantSelector = ({
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const debouncedSetTemperature = useCallback(
|
|
||||||
(value: number) => {
|
|
||||||
const debouncedFunction = debounce((value: number) => {
|
|
||||||
llmOverrideManager?.setTemperature(value);
|
|
||||||
}, 300);
|
|
||||||
return debouncedFunction(value);
|
|
||||||
},
|
|
||||||
[llmOverrideManager]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleTemperatureChange = (value: number) => {
|
|
||||||
setLocalTemperature(value);
|
|
||||||
debouncedSetTemperature(value);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Handle tab change and update localStorage
|
// Handle tab change and update localStorage
|
||||||
const handleTabChange = (index: number) => {
|
const handleTabChange = (index: number) => {
|
||||||
setSelectedTab(index);
|
setSelectedTab(index);
|
||||||
@ -119,7 +102,7 @@ const AssistantSelector = ({
|
|||||||
const [_, currentLlm] = getFinalLLM(
|
const [_, currentLlm] = getFinalLLM(
|
||||||
llmProviders,
|
llmProviders,
|
||||||
liveAssistant,
|
liveAssistant,
|
||||||
llmOverrideManager?.llmOverride ?? null
|
llmOverrideManager.llmOverride ?? null
|
||||||
);
|
);
|
||||||
|
|
||||||
const requiresImageGeneration =
|
const requiresImageGeneration =
|
||||||
@ -204,11 +187,10 @@ const AssistantSelector = ({
|
|||||||
llmProviders={llmProviders}
|
llmProviders={llmProviders}
|
||||||
currentLlm={currentLlm}
|
currentLlm={currentLlm}
|
||||||
userDefault={userDefaultModel}
|
userDefault={userDefaultModel}
|
||||||
includeUserDefault={true}
|
|
||||||
onSelect={(value: string | null) => {
|
onSelect={(value: string | null) => {
|
||||||
if (value == null) return;
|
if (value == null) return;
|
||||||
const { modelName, name, provider } = destructureValue(value);
|
const { modelName, name, provider } = destructureValue(value);
|
||||||
llmOverrideManager?.setLlmOverride({
|
llmOverrideManager.setLlmOverride({
|
||||||
name,
|
name,
|
||||||
provider,
|
provider,
|
||||||
modelName,
|
modelName,
|
||||||
@ -216,7 +198,6 @@ const AssistantSelector = ({
|
|||||||
if (chatSessionId) {
|
if (chatSessionId) {
|
||||||
updateModelOverrideForChatSession(chatSessionId, value);
|
updateModelOverrideForChatSession(chatSessionId, value);
|
||||||
}
|
}
|
||||||
setIsOpen(false);
|
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
<div className="mt-4">
|
<div className="mt-4">
|
||||||
@ -243,26 +224,31 @@ const AssistantSelector = ({
|
|||||||
<input
|
<input
|
||||||
type="range"
|
type="range"
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
handleTemperatureChange(parseFloat(e.target.value))
|
llmOverrideManager.updateTemperature(
|
||||||
|
parseFloat(e.target.value)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
className="w-full p-2 border border-border rounded-md"
|
className="w-full p-2 border border-border rounded-md"
|
||||||
min="0"
|
min="0"
|
||||||
max="2"
|
max="2"
|
||||||
step="0.01"
|
step="0.01"
|
||||||
value={localTemperature}
|
value={llmOverrideManager.temperature?.toString() || "0"}
|
||||||
/>
|
/>
|
||||||
<div
|
<div
|
||||||
className="absolute text-sm"
|
className="absolute text-sm"
|
||||||
style={{
|
style={{
|
||||||
left: `${(localTemperature || 0) * 50}%`,
|
left: `${(llmOverrideManager.temperature || 0) * 50}%`,
|
||||||
transform: `translateX(-${Math.min(
|
transform: `translateX(-${Math.min(
|
||||||
Math.max((localTemperature || 0) * 50, 10),
|
Math.max(
|
||||||
|
(llmOverrideManager.temperature || 0) * 50,
|
||||||
|
10
|
||||||
|
),
|
||||||
90
|
90
|
||||||
)}%)`,
|
)}%)`,
|
||||||
top: "-1.5rem",
|
top: "-1.5rem",
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{localTemperature}
|
{llmOverrideManager.temperature}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
|
@ -19,7 +19,6 @@ interface LlmListProps {
|
|||||||
scrollable?: boolean;
|
scrollable?: boolean;
|
||||||
hideProviderIcon?: boolean;
|
hideProviderIcon?: boolean;
|
||||||
requiresImageGeneration?: boolean;
|
requiresImageGeneration?: boolean;
|
||||||
includeUserDefault?: boolean;
|
|
||||||
currentAssistant?: Persona;
|
currentAssistant?: Persona;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -31,7 +30,6 @@ export const LlmList: React.FC<LlmListProps> = ({
|
|||||||
userDefault,
|
userDefault,
|
||||||
scrollable,
|
scrollable,
|
||||||
requiresImageGeneration,
|
requiresImageGeneration,
|
||||||
includeUserDefault = false,
|
|
||||||
}) => {
|
}) => {
|
||||||
const llmOptionsByProvider: {
|
const llmOptionsByProvider: {
|
||||||
[provider: string]: {
|
[provider: string]: {
|
||||||
|
@ -16,6 +16,7 @@ import { UsersResponse } from "./users/interfaces";
|
|||||||
import { Credential } from "./connectors/credentials";
|
import { Credential } from "./connectors/credentials";
|
||||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||||
import { PersonaCategory } from "@/app/admin/assistants/interfaces";
|
import { PersonaCategory } from "@/app/admin/assistants/interfaces";
|
||||||
|
import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
|
||||||
|
|
||||||
const CREDENTIAL_URL = "/api/manage/admin/credential";
|
const CREDENTIAL_URL = "/api/manage/admin/credential";
|
||||||
|
|
||||||
@ -71,7 +72,9 @@ export const useConnectorCredentialIndexingStatus = (
|
|||||||
getEditable = false
|
getEditable = false
|
||||||
) => {
|
) => {
|
||||||
const { mutate } = useSWRConfig();
|
const { mutate } = useSWRConfig();
|
||||||
const url = `${INDEXING_STATUS_URL}${getEditable ? "?get_editable=true" : ""}`;
|
const url = `${INDEXING_STATUS_URL}${
|
||||||
|
getEditable ? "?get_editable=true" : ""
|
||||||
|
}`;
|
||||||
const swrResponse = useSWR<ConnectorIndexingStatus<any, any>[]>(
|
const swrResponse = useSWR<ConnectorIndexingStatus<any, any>[]>(
|
||||||
url,
|
url,
|
||||||
errorHandlingFetcher,
|
errorHandlingFetcher,
|
||||||
@ -157,7 +160,7 @@ export interface LlmOverrideManager {
|
|||||||
globalDefault: LlmOverride;
|
globalDefault: LlmOverride;
|
||||||
setGlobalDefault: React.Dispatch<React.SetStateAction<LlmOverride>>;
|
setGlobalDefault: React.Dispatch<React.SetStateAction<LlmOverride>>;
|
||||||
temperature: number | null;
|
temperature: number | null;
|
||||||
setTemperature: React.Dispatch<React.SetStateAction<number | null>>;
|
updateTemperature: (temperature: number | null) => void;
|
||||||
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
|
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
|
||||||
}
|
}
|
||||||
export function useLlmOverride(
|
export function useLlmOverride(
|
||||||
@ -212,6 +215,20 @@ export function useLlmOverride(
|
|||||||
setTemperature(defaultTemperature !== undefined ? defaultTemperature : 0);
|
setTemperature(defaultTemperature !== undefined ? defaultTemperature : 0);
|
||||||
}, [defaultTemperature]);
|
}, [defaultTemperature]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
|
||||||
|
setTemperature((prevTemp) => Math.min(prevTemp ?? 0, 1.0));
|
||||||
|
}
|
||||||
|
}, [llmOverride]);
|
||||||
|
|
||||||
|
const updateTemperature = (temperature: number | null) => {
|
||||||
|
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
|
||||||
|
setTemperature((prevTemp) => Math.min(temperature ?? 0, 1.0));
|
||||||
|
} else {
|
||||||
|
setTemperature(temperature);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
return {
|
return {
|
||||||
updateModelOverrideForChatSession,
|
updateModelOverrideForChatSession,
|
||||||
llmOverride,
|
llmOverride,
|
||||||
@ -219,9 +236,10 @@ export function useLlmOverride(
|
|||||||
globalDefault,
|
globalDefault,
|
||||||
setGlobalDefault,
|
setGlobalDefault,
|
||||||
temperature,
|
temperature,
|
||||||
setTemperature,
|
updateTemperature,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
EE Only APIs
|
EE Only APIs
|
||||||
*/
|
*/
|
||||||
|
Reference in New Issue
Block a user