diff --git a/backend/alembic/versions/2f80c6a2550f_add_chat_session_specific_temperature_.py b/backend/alembic/versions/2f80c6a2550f_add_chat_session_specific_temperature_.py new file mode 100644 index 000000000..07259ec44 --- /dev/null +++ b/backend/alembic/versions/2f80c6a2550f_add_chat_session_specific_temperature_.py @@ -0,0 +1,36 @@ +"""add chat session specific temperature override + +Revision ID: 2f80c6a2550f +Revises: 33ea50e88f24 +Create Date: 2025-01-31 10:30:27.289646 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "2f80c6a2550f" +down_revision = "33ea50e88f24" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "chat_session", sa.Column("temperature_override", sa.Float(), nullable=True) + ) + op.add_column( + "user", + sa.Column( + "temperature_override_enabled", + sa.Boolean(), + nullable=False, + server_default=sa.false(), + ), + ) + + +def downgrade() -> None: + op.drop_column("chat_session", "temperature_override") + op.drop_column("user", "temperature_override_enabled") diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index cdc96fd54..83747eef8 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -150,6 +150,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base): # if specified, controls the assistants that are shown to the user + their order # if not specified, all assistants are shown + temperature_override_enabled: Mapped[bool] = mapped_column(Boolean, default=False) auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True) shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=True) chosen_assistants: Mapped[list[int] | None] = mapped_column( @@ -1115,6 +1116,10 @@ class ChatSession(Base): llm_override: Mapped[LLMOverride | None] = mapped_column( PydanticType(LLMOverride), nullable=True ) + + # The latest temperature override specified by the user + temperature_override: Mapped[float | None] = mapped_column(Float, nullable=True) + prompt_override: Mapped[PromptOverride | None] = mapped_column( PydanticType(PromptOverride), nullable=True ) diff --git a/backend/onyx/natural_language_processing/search_nlp_models.py b/backend/onyx/natural_language_processing/search_nlp_models.py index 2f6c6d306..5f1f2d59a 100644 --- a/backend/onyx/natural_language_processing/search_nlp_models.py +++ b/backend/onyx/natural_language_processing/search_nlp_models.py @@ -175,7 +175,6 @@ class EmbeddingModel: if self.callback.should_stop(): raise RuntimeError("_batch_encode_texts detected stop signal") - logger.debug(f"Encoding batch {batch_idx} of {len(text_batches)}") embed_request = EmbedRequest( model_name=self.model_name, texts=text_batch, @@ -191,7 +190,15 @@ class EmbeddingModel: api_url=self.api_url, ) + start_time = time.time() response = self._make_model_server_request(embed_request) + end_time = time.time() + + processing_time = end_time - start_time + logger.info( + f"Batch {batch_idx} processing time: {processing_time:.2f} seconds" + ) + return batch_idx, response.embeddings # only multi thread if: diff --git a/backend/onyx/server/manage/models.py b/backend/onyx/server/manage/models.py index 81a04708e..755b11e04 100644 --- a/backend/onyx/server/manage/models.py +++ b/backend/onyx/server/manage/models.py @@ -48,6 +48,7 @@ class UserPreferences(BaseModel): auto_scroll: bool | None = None pinned_assistants: list[int] | None = None shortcut_enabled: bool | None = None + temperature_override_enabled: bool | None = None class UserInfo(BaseModel): @@ -91,6 +92,7 @@ class UserInfo(BaseModel): hidden_assistants=user.hidden_assistants, pinned_assistants=user.pinned_assistants, visible_assistants=user.visible_assistants, + temperature_override_enabled=user.temperature_override_enabled, ) ), organization_name=organization_name, diff --git a/backend/onyx/server/manage/users.py b/backend/onyx/server/manage/users.py index 3988c90e4..3ef0faf39 100644 --- a/backend/onyx/server/manage/users.py +++ b/backend/onyx/server/manage/users.py @@ -568,6 +568,32 @@ def verify_user_logged_in( """APIs to adjust user preferences""" +@router.patch("/temperature-override-enabled") +def update_user_temperature_override_enabled( + temperature_override_enabled: bool, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + if user is None: + if AUTH_TYPE == AuthType.DISABLED: + store = get_kv_store() + no_auth_user = fetch_no_auth_user(store) + no_auth_user.preferences.temperature_override_enabled = ( + temperature_override_enabled + ) + set_no_auth_user_preferences(store, no_auth_user.preferences) + return + else: + raise RuntimeError("This should never happen") + + db_session.execute( + update(User) + .where(User.id == user.id) # type: ignore + .values(temperature_override_enabled=temperature_override_enabled) + ) + db_session.commit() + + class ChosenDefaultModelRequest(BaseModel): default_model: str | None = None diff --git a/backend/onyx/server/query_and_chat/chat_backend.py b/backend/onyx/server/query_and_chat/chat_backend.py index 53032fb46..ad583cd62 100644 --- a/backend/onyx/server/query_and_chat/chat_backend.py +++ b/backend/onyx/server/query_and_chat/chat_backend.py @@ -77,6 +77,7 @@ from onyx.server.query_and_chat.models import LLMOverride from onyx.server.query_and_chat.models import PromptOverride from onyx.server.query_and_chat.models import RenameChatSessionResponse from onyx.server.query_and_chat.models import SearchFeedbackRequest +from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest from onyx.server.query_and_chat.token_limit import check_token_rate_limits from onyx.utils.headers import get_custom_tool_additional_request_headers @@ -114,12 +115,52 @@ def get_user_chat_sessions( shared_status=chat.shared_status, folder_id=chat.folder_id, current_alternate_model=chat.current_alternate_model, + current_temperature_override=chat.temperature_override, ) for chat in chat_sessions ] ) +@router.put("/update-chat-session-temperature") +def update_chat_session_temperature( + update_thread_req: UpdateChatSessionTemperatureRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + chat_session = get_chat_session_by_id( + chat_session_id=update_thread_req.chat_session_id, + user_id=user.id if user is not None else None, + db_session=db_session, + ) + + # Validate temperature_override + if update_thread_req.temperature_override is not None: + if ( + update_thread_req.temperature_override < 0 + or update_thread_req.temperature_override > 2 + ): + raise HTTPException( + status_code=400, detail="Temperature must be between 0 and 2" + ) + + # Additional check for Anthropic models + if ( + chat_session.current_alternate_model + and "anthropic" in chat_session.current_alternate_model.lower() + ): + if update_thread_req.temperature_override > 1: + raise HTTPException( + status_code=400, + detail="Temperature for Anthropic models must be between 0 and 1", + ) + + chat_session.temperature_override = update_thread_req.temperature_override + + db_session.add(chat_session) + db_session.commit() + + @router.put("/update-chat-session-model") def update_chat_session_model( update_thread_req: UpdateChatSessionThreadRequest, @@ -190,6 +231,7 @@ def get_chat_session( ], time_created=chat_session.time_created, shared_status=chat_session.shared_status, + current_temperature_override=chat_session.temperature_override, ) diff --git a/backend/onyx/server/query_and_chat/models.py b/backend/onyx/server/query_and_chat/models.py index e44982f75..b107460bf 100644 --- a/backend/onyx/server/query_and_chat/models.py +++ b/backend/onyx/server/query_and_chat/models.py @@ -42,6 +42,11 @@ class UpdateChatSessionThreadRequest(BaseModel): new_alternate_model: str +class UpdateChatSessionTemperatureRequest(BaseModel): + chat_session_id: UUID + temperature_override: float + + class ChatSessionCreationRequest(BaseModel): # If not specified, use Onyx default persona persona_id: int = 0 @@ -108,6 +113,10 @@ class CreateChatMessageRequest(ChunkContext): llm_override: LLMOverride | None = None prompt_override: PromptOverride | None = None + # Allows the caller to override the temperature for the chat session + # this does persist in the chat thread details + temperature_override: float | None = None + # allow user to specify an alternate assistnat alternate_assistant_id: int | None = None @@ -168,6 +177,7 @@ class ChatSessionDetails(BaseModel): shared_status: ChatSessionSharedStatus folder_id: int | None = None current_alternate_model: str | None = None + current_temperature_override: float | None = None class ChatSessionsResponse(BaseModel): @@ -231,6 +241,7 @@ class ChatSessionDetailResponse(BaseModel): time_created: datetime shared_status: ChatSessionSharedStatus current_alternate_model: str | None + current_temperature_override: float | None # This one is not used anymore diff --git a/web/package-lock.json b/web/package-lock.json index 943cc05d3..d5a074eb1 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -25,6 +25,7 @@ "@radix-ui/react-scroll-area": "^1.2.2", "@radix-ui/react-select": "^2.1.2", "@radix-ui/react-separator": "^1.1.0", + "@radix-ui/react-slider": "^1.2.2", "@radix-ui/react-slot": "^1.1.0", "@radix-ui/react-switch": "^1.1.1", "@radix-ui/react-tabs": "^1.1.1", @@ -4963,6 +4964,142 @@ } } }, + "node_modules/@radix-ui/react-slider": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slider/-/react-slider-1.2.2.tgz", + "integrity": "sha512-sNlU06ii1/ZcbHf8I9En54ZPW0Vil/yPVg4vQMcFNjrIx51jsHbFl1HYHQvCIWJSr1q0ZmA+iIs/ZTv8h7HHSA==", + "license": "MIT", + "dependencies": { + "@radix-ui/number": "1.1.0", + "@radix-ui/primitive": "1.1.1", + "@radix-ui/react-collection": "1.1.1", + "@radix-ui/react-compose-refs": "1.1.1", + "@radix-ui/react-context": "1.1.1", + "@radix-ui/react-direction": "1.1.0", + "@radix-ui/react-primitive": "2.0.1", + "@radix-ui/react-use-controllable-state": "1.1.0", + "@radix-ui/react-use-layout-effect": "1.1.0", + "@radix-ui/react-use-previous": "1.1.0", + "@radix-ui/react-use-size": "1.1.0" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-slider/node_modules/@radix-ui/primitive": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.1.tgz", + "integrity": "sha512-SJ31y+Q/zAyShtXJc8x83i9TYdbAfHZ++tUZnvjJJqFjzsdUnKsxPL6IEtBlxKkU7yzer//GQtZSV4GbldL3YA==", + "license": "MIT" + }, + "node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-collection": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.1.tgz", + "integrity": "sha512-LwT3pSho9Dljg+wY2KN2mrrh6y3qELfftINERIzBUO9e0N+t0oMTyn3k9iv+ZqgrwGkRnLpNJrsMv9BZlt2yuA==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.1", + "@radix-ui/react-context": "1.1.1", + "@radix-ui/react-primitive": "2.0.1", + "@radix-ui/react-slot": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-compose-refs": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.1.tgz", + "integrity": "sha512-Y9VzoRDSJtgFMUCoiZBDVo084VQ5hfpXxVE+NgkdNsjiDBByiImMZKKhxMwCbdHvhlENG6a833CbFkOQvTricw==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-context": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz", + "integrity": "sha512-UASk9zi+crv9WteK/NU4PLvOoL3OuE6BWVKNF6hPRBtYBDXQ2u5iu3O59zUlJiTVvkyuycnqrztsHVJwcK9K+Q==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-primitive": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.0.1.tgz", + "integrity": "sha512-sHCWTtxwNn3L3fH8qAfnF3WbUZycW93SM1j3NFDzXBiz8D6F5UTTy8G1+WFEaiCdvCVRJWj6N2R4Xq6HdiHmDg==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-slot": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-slot": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.1.1.tgz", + "integrity": "sha512-RApLLOcINYJA+dMVbOju7MYv1Mb2EBp2nH4HdDzXTSyaR5optlm6Otrz1euW3HbdOR8UmmFK06TD+A9frYWv+g==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-slot": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.1.0.tgz", diff --git a/web/package.json b/web/package.json index ec861782b..476880a99 100644 --- a/web/package.json +++ b/web/package.json @@ -28,6 +28,7 @@ "@radix-ui/react-scroll-area": "^1.2.2", "@radix-ui/react-select": "^2.1.2", "@radix-ui/react-separator": "^1.1.0", + "@radix-ui/react-slider": "^1.2.2", "@radix-ui/react-slot": "^1.1.0", "@radix-ui/react-switch": "^1.1.1", "@radix-ui/react-tabs": "^1.1.1", diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 376a4bff2..49ea64315 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -404,9 +404,6 @@ export function ChatPage({ filterManager.setSelectedTags([]); filterManager.setTimeRange(null); - // reset LLM overrides (based on chat session!) - llmOverrideManager.updateTemperature(null); - // remove uploaded files setCurrentMessageFiles([]); @@ -449,6 +446,7 @@ export function ChatPage({ ); const chatSession = (await response.json()) as BackendChatSession; + setSelectedAssistantFromId(chatSession.persona_id); const newMessageMap = processRawChatHistory(chatSession.messages); diff --git a/web/src/app/chat/input/LLMPopover.tsx b/web/src/app/chat/input/LLMPopover.tsx index 36589c550..9cdab8b57 100644 --- a/web/src/app/chat/input/LLMPopover.tsx +++ b/web/src/app/chat/input/LLMPopover.tsx @@ -1,4 +1,4 @@ -import React, { useState } from "react"; +import React, { useState, useEffect } from "react"; import { Popover, PopoverContent, @@ -26,6 +26,9 @@ import { } from "@/components/ui/tooltip"; import { FiAlertTriangle } from "react-icons/fi"; +import { Slider } from "@/components/ui/slider"; +import { useUser } from "@/components/user/UserProvider"; + interface LLMPopoverProps { llmProviders: LLMProviderDescriptor[]; llmOverrideManager: LlmOverrideManager; @@ -40,6 +43,7 @@ export default function LLMPopover({ currentAssistant, }: LLMPopoverProps) { const [isOpen, setIsOpen] = useState(false); + const { user } = useUser(); const { llmOverride, updateLLMOverride } = llmOverrideManager; const currentLlm = llmOverride.modelName; @@ -88,6 +92,22 @@ export default function LLMPopover({ ? getDisplayNameForModel(defaultModelName) : null; + const [localTemperature, setLocalTemperature] = useState( + llmOverrideManager.temperature ?? 0.5 + ); + + useEffect(() => { + setLocalTemperature(llmOverrideManager.temperature ?? 0.5); + }, [llmOverrideManager.temperature]); + + const handleTemperatureChange = (value: number[]) => { + setLocalTemperature(value[0]); + }; + + const handleTemperatureChangeComplete = (value: number[]) => { + llmOverrideManager.updateTemperature(value[0]); + }; + return ( @@ -118,9 +138,9 @@ export default function LLMPopover({ -
+
{llmOptions.map(({ name, icon, value }, index) => { if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) { return ( @@ -171,6 +191,25 @@ export default function LLMPopover({ return null; })}
+ {user?.preferences?.temperature_override_enabled && ( +
+
+ +
+ Temperature (creativity) + {localTemperature.toFixed(1)} +
+
+
+ )} ); diff --git a/web/src/app/chat/interfaces.ts b/web/src/app/chat/interfaces.ts index df6546d11..337a5897c 100644 --- a/web/src/app/chat/interfaces.ts +++ b/web/src/app/chat/interfaces.ts @@ -68,6 +68,7 @@ export interface ChatSession { shared_status: ChatSessionSharedStatus; folder_id: number | null; current_alternate_model: string; + current_temperature_override: number | null; } export interface SearchSession { @@ -107,6 +108,7 @@ export interface BackendChatSession { messages: BackendMessage[]; time_created: string; shared_status: ChatSessionSharedStatus; + current_temperature_override: number | null; current_alternate_model?: string; } diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index e5aa62768..50f9614c1 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -75,6 +75,23 @@ export async function updateModelOverrideForChatSession( return response; } +export async function updateTemperatureOverrideForChatSession( + chatSessionId: string, + newTemperature: number +) { + const response = await fetch("/api/chat/update-chat-session-temperature", { + method: "PUT", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + chat_session_id: chatSessionId, + temperature_override: newTemperature, + }), + }); + return response; +} + export async function createChatSession( personaId: number, description: string | null diff --git a/web/src/app/chat/modal/UserSettingsModal.tsx b/web/src/app/chat/modal/UserSettingsModal.tsx index 17342126f..d5fccf3dd 100644 --- a/web/src/app/chat/modal/UserSettingsModal.tsx +++ b/web/src/app/chat/modal/UserSettingsModal.tsx @@ -30,8 +30,13 @@ export function UserSettingsModal({ defaultModel: string | null; }) { const { inputPrompts, refreshInputPrompts } = useChatContext(); - const { refreshUser, user, updateUserAutoScroll, updateUserShortcuts } = - useUser(); + const { + refreshUser, + user, + updateUserAutoScroll, + updateUserShortcuts, + updateUserTemperatureOverrideEnabled, + } = useUser(); const containerRef = useRef(null); const messageRef = useRef(null); @@ -179,6 +184,16 @@ export function UserSettingsModal({ />
+
+ { + updateUserTemperatureOverrideEnabled(checked); + }} + /> + +
diff --git a/web/src/components/Modal.tsx b/web/src/components/Modal.tsx index f32673c8a..f926a399d 100644 --- a/web/src/components/Modal.tsx +++ b/web/src/components/Modal.tsx @@ -103,7 +103,7 @@ export function Modal({ )} -
+
{title && ( <>
diff --git a/web/src/components/ui/slider.tsx b/web/src/components/ui/slider.tsx new file mode 100644 index 000000000..0f8d16f5f --- /dev/null +++ b/web/src/components/ui/slider.tsx @@ -0,0 +1,28 @@ +"use client"; + +import * as React from "react"; +import * as SliderPrimitive from "@radix-ui/react-slider"; + +import { cn } from "@/lib/utils"; + +const Slider = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + + + + +)); +Slider.displayName = SliderPrimitive.Root.displayName; + +export { Slider }; diff --git a/web/src/components/user/UserProvider.tsx b/web/src/components/user/UserProvider.tsx index 8372ef268..701e3181a 100644 --- a/web/src/components/user/UserProvider.tsx +++ b/web/src/components/user/UserProvider.tsx @@ -18,6 +18,7 @@ interface UserContextType { assistantId: number, isPinned: boolean ) => Promise; + updateUserTemperatureOverrideEnabled: (enabled: boolean) => Promise; } const UserContext = createContext(undefined); @@ -57,6 +58,41 @@ export function UserProvider({ console.error("Error fetching current user:", error); } }; + const updateUserTemperatureOverrideEnabled = async (enabled: boolean) => { + try { + setUpToDateUser((prevUser) => { + if (prevUser) { + return { + ...prevUser, + preferences: { + ...prevUser.preferences, + temperature_override_enabled: enabled, + }, + }; + } + return prevUser; + }); + + const response = await fetch( + `/api/temperature-override-enabled?temperature_override_enabled=${enabled}`, + { + method: "PATCH", + headers: { + "Content-Type": "application/json", + }, + } + ); + + if (!response.ok) { + await refreshUser(); + throw new Error("Failed to update user temperature override setting"); + } + } catch (error) { + console.error("Error updating user temperature override setting:", error); + throw error; + } + }; + const updateUserShortcuts = async (enabled: boolean) => { try { setUpToDateUser((prevUser) => { @@ -184,6 +220,7 @@ export function UserProvider({ refreshUser, updateUserAutoScroll, updateUserShortcuts, + updateUserTemperatureOverrideEnabled, toggleAssistantPinnedStatus, isAdmin: upToDateUser?.role === UserRole.ADMIN, // Curator status applies for either global or basic curator diff --git a/web/src/lib/hooks.ts b/web/src/lib/hooks.ts index c179f9d7a..523f50728 100644 --- a/web/src/lib/hooks.ts +++ b/web/src/lib/hooks.ts @@ -10,7 +10,7 @@ import { } from "@/lib/types"; import useSWR, { mutate, useSWRConfig } from "swr"; import { errorHandlingFetcher } from "./fetcher"; -import { useContext, useEffect, useState } from "react"; +import { useContext, useEffect, useMemo, useState } from "react"; import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector"; import { Filters, SourceMetadata } from "./search/interfaces"; import { @@ -28,6 +28,8 @@ import { isAnthropic } from "@/app/admin/configuration/llm/interfaces"; import { getSourceMetadata } from "./sources"; import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants"; import { useUser } from "@/components/user/UserProvider"; +import { SEARCH_TOOL_ID } from "@/app/chat/tools/constants"; +import { updateTemperatureOverrideForChatSession } from "@/app/chat/lib"; const CREDENTIAL_URL = "/api/manage/admin/credential"; @@ -360,12 +362,13 @@ export interface LlmOverride { export interface LlmOverrideManager { llmOverride: LlmOverride; updateLLMOverride: (newOverride: LlmOverride) => void; - temperature: number | null; - updateTemperature: (temperature: number | null) => void; + temperature: number; + updateTemperature: (temperature: number) => void; updateModelOverrideForChatSession: (chatSession?: ChatSession) => void; imageFilesPresent: boolean; updateImageFilesPresent: (present: boolean) => void; liveAssistant: Persona | null; + maxTemperature: number; } // Things to test @@ -395,6 +398,18 @@ Changes take place as If we have a live assistant, we should use that model override Relevant test: `llm_ordering.spec.ts`. + +Temperature override is set as follows: +- For existing chat sessions: + - If the user has previously overridden the temperature for a specific chat session, + that value is persisted and used when the user returns to that chat. + - This persistence applies even if the temperature was set before sending the first message in the chat. +- For new chat sessions: + - If the search tool is available, the default temperature is set to 0. + - If the search tool is not available, the default temperature is set to 0.5. + +This approach ensures that user preferences are maintained for existing chats while +providing appropriate defaults for new conversations based on the available tools. */ export function useLlmOverride( @@ -407,11 +422,6 @@ export function useLlmOverride( const [chatSession, setChatSession] = useState(null); const llmOverrideUpdate = () => { - if (!chatSession && currentChatSession) { - setChatSession(currentChatSession || null); - return; - } - if (liveAssistant?.llm_model_version_override) { setLlmOverride( getValidLlmOverride(liveAssistant.llm_model_version_override) @@ -499,24 +509,68 @@ export function useLlmOverride( } }; - const [temperature, setTemperature] = useState(0); - - useEffect(() => { + const [temperature, setTemperature] = useState(() => { llmOverrideUpdate(); - }, [liveAssistant, currentChatSession]); + + if (currentChatSession?.current_temperature_override != null) { + return Math.min( + currentChatSession.current_temperature_override, + isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0 + ); + } else if ( + liveAssistant?.tools.some((tool) => tool.name === SEARCH_TOOL_ID) + ) { + return 0; + } + return 0.5; + }); + + const maxTemperature = useMemo(() => { + return isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0; + }, [llmOverride]); useEffect(() => { if (isAnthropic(llmOverride.provider, llmOverride.modelName)) { - setTemperature((prevTemp) => Math.min(prevTemp ?? 0, 1.0)); + const newTemperature = Math.min(temperature, 1.0); + setTemperature(newTemperature); + if (chatSession?.id) { + updateTemperatureOverrideForChatSession(chatSession.id, newTemperature); + } } }, [llmOverride]); - const updateTemperature = (temperature: number | null) => { + useEffect(() => { + if (!chatSession && currentChatSession) { + setChatSession(currentChatSession || null); + if (temperature) { + updateTemperatureOverrideForChatSession( + currentChatSession.id, + temperature + ); + } + return; + } + + if (currentChatSession?.current_temperature_override) { + setTemperature(currentChatSession.current_temperature_override); + } else if ( + liveAssistant?.tools.some((tool) => tool.name === SEARCH_TOOL_ID) + ) { + setTemperature(0); + } else { + setTemperature(0.5); + } + }, [liveAssistant, currentChatSession]); + + const updateTemperature = (temperature: number) => { if (isAnthropic(llmOverride.provider, llmOverride.modelName)) { - setTemperature((prevTemp) => Math.min(temperature ?? 0, 1.0)); + setTemperature((prevTemp) => Math.min(temperature, 1.0)); } else { setTemperature(temperature); } + if (chatSession) { + updateTemperatureOverrideForChatSession(chatSession.id, temperature); + } }; return { @@ -528,6 +582,7 @@ export function useLlmOverride( imageFilesPresent, updateImageFilesPresent, liveAssistant: liveAssistant ?? null, + maxTemperature, }; } diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 414b23b46..4f2ac5b40 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -12,6 +12,7 @@ interface UserPreferences { recent_assistants: number[]; auto_scroll: boolean | null; shortcut_enabled: boolean; + temperature_override_enabled: boolean; } export enum UserRole {