Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-05-02 16:00:34 +02:00
Add user specific chat session temperature (#3867)
* add user specific chat session temperature
* better typing
* update
This commit is contained in:
parent 6a7e2a8036
commit 11da0d9889
@ -0,0 +1,36 @@
"""add chat session specific temperature override

Revision ID: 2f80c6a2550f
Revises: 33ea50e88f24
Create Date: 2025-01-31 10:30:27.289646

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "2f80c6a2550f"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
    )
    op.add_column(
        "user",
        sa.Column(
            "temperature_override_enabled",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    )


def downgrade() -> None:
    op.drop_column("chat_session", "temperature_override")
    op.drop_column("user", "temperature_override_enabled")
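For reference, this revision can be applied or rolled back through Alembic's Python API as well as the CLI. A minimal sketch, assuming it is run from the backend directory that contains the project's alembic.ini:

# Minimal sketch, assuming alembic.ini is in the working directory.
from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "2f80c6a2550f")    # adds chat_session.temperature_override and user.temperature_override_enabled
command.downgrade(alembic_cfg, "33ea50e88f24")  # drops both columns again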
@ -150,6 +150,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):

    # if specified, controls the assistants that are shown to the user + their order
    # if not specified, all assistants are shown
    temperature_override_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
    auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
    shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    chosen_assistants: Mapped[list[int] | None] = mapped_column(
@ -1115,6 +1116,10 @@ class ChatSession(Base):
    llm_override: Mapped[LLMOverride | None] = mapped_column(
        PydanticType(LLMOverride), nullable=True
    )

    # The latest temperature override specified by the user
    temperature_override: Mapped[float | None] = mapped_column(Float, nullable=True)

    prompt_override: Mapped[PromptOverride | None] = mapped_column(
        PydanticType(PromptOverride), nullable=True
    )
@ -175,7 +175,6 @@ class EmbeddingModel:
            if self.callback.should_stop():
                raise RuntimeError("_batch_encode_texts detected stop signal")

            logger.debug(f"Encoding batch {batch_idx} of {len(text_batches)}")
            embed_request = EmbedRequest(
                model_name=self.model_name,
                texts=text_batch,
@ -191,7 +190,15 @@ class EmbeddingModel:
                api_url=self.api_url,
            )

            start_time = time.time()
            response = self._make_model_server_request(embed_request)
            end_time = time.time()

            processing_time = end_time - start_time
            logger.info(
                f"Batch {batch_idx} processing time: {processing_time:.2f} seconds"
            )

            return batch_idx, response.embeddings

        # only multi thread if:
@ -48,6 +48,7 @@ class UserPreferences(BaseModel):
    auto_scroll: bool | None = None
    pinned_assistants: list[int] | None = None
    shortcut_enabled: bool | None = None
    temperature_override_enabled: bool | None = None


class UserInfo(BaseModel):
@ -91,6 +92,7 @@ class UserInfo(BaseModel):
                hidden_assistants=user.hidden_assistants,
                pinned_assistants=user.pinned_assistants,
                visible_assistants=user.visible_assistants,
                temperature_override_enabled=user.temperature_override_enabled,
            )
        ),
        organization_name=organization_name,
@ -568,6 +568,32 @@ def verify_user_logged_in(
"""APIs to adjust user preferences"""


@router.patch("/temperature-override-enabled")
def update_user_temperature_override_enabled(
    temperature_override_enabled: bool,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    if user is None:
        if AUTH_TYPE == AuthType.DISABLED:
            store = get_kv_store()
            no_auth_user = fetch_no_auth_user(store)
            no_auth_user.preferences.temperature_override_enabled = (
                temperature_override_enabled
            )
            set_no_auth_user_preferences(store, no_auth_user.preferences)
            return
        else:
            raise RuntimeError("This should never happen")

    db_session.execute(
        update(User)
        .where(User.id == user.id)  # type: ignore
        .values(temperature_override_enabled=temperature_override_enabled)
    )
    db_session.commit()


class ChosenDefaultModelRequest(BaseModel):
    default_model: str | None = None
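For reference, a minimal sketch of exercising this preference endpoint over HTTP. The frontend calls it as PATCH /api/temperature-override-enabled with a query parameter (shown later in this diff); the base URL and authentication below are deployment-specific assumptions:

import requests

# Assumed local deployment; adjust the base URL and auth for your setup.
BASE_URL = "http://localhost:3000/api"
session = requests.Session()  # expected to carry a valid auth cookie when auth is enabled

resp = session.patch(
    f"{BASE_URL}/temperature-override-enabled",
    params={"temperature_override_enabled": "true"},
)
resp.raise_for_status()  # on success the flag is stored on the User row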
@ -77,6 +77,7 @@ from onyx.server.query_and_chat.models import LLMOverride
from onyx.server.query_and_chat.models import PromptOverride
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SearchFeedbackRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.token_limit import check_token_rate_limits
from onyx.utils.headers import get_custom_tool_additional_request_headers
@ -114,12 +115,52 @@ def get_user_chat_sessions(
                shared_status=chat.shared_status,
                folder_id=chat.folder_id,
                current_alternate_model=chat.current_alternate_model,
                current_temperature_override=chat.temperature_override,
            )
            for chat in chat_sessions
        ]
    )


@router.put("/update-chat-session-temperature")
def update_chat_session_temperature(
    update_thread_req: UpdateChatSessionTemperatureRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    chat_session = get_chat_session_by_id(
        chat_session_id=update_thread_req.chat_session_id,
        user_id=user.id if user is not None else None,
        db_session=db_session,
    )

    # Validate temperature_override
    if update_thread_req.temperature_override is not None:
        if (
            update_thread_req.temperature_override < 0
            or update_thread_req.temperature_override > 2
        ):
            raise HTTPException(
                status_code=400, detail="Temperature must be between 0 and 2"
            )

        # Additional check for Anthropic models
        if (
            chat_session.current_alternate_model
            and "anthropic" in chat_session.current_alternate_model.lower()
        ):
            if update_thread_req.temperature_override > 1:
                raise HTTPException(
                    status_code=400,
                    detail="Temperature for Anthropic models must be between 0 and 1",
                )

    chat_session.temperature_override = update_thread_req.temperature_override

    db_session.add(chat_session)
    db_session.commit()


@router.put("/update-chat-session-model")
def update_chat_session_model(
    update_thread_req: UpdateChatSessionThreadRequest,
@ -190,6 +231,7 @@ def get_chat_session(
        ],
        time_created=chat_session.time_created,
        shared_status=chat_session.shared_status,
        current_temperature_override=chat_session.temperature_override,
    )
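A similar sketch for the new per-session endpoint, which the frontend reaches as PUT /api/chat/update-chat-session-temperature (see the lib.ts helper later in this diff). The base URL and the chat_session_id below are placeholders:

import requests

BASE_URL = "http://localhost:3000/api"  # assumed deployment URL

resp = requests.put(
    f"{BASE_URL}/chat/update-chat-session-temperature",
    json={
        "chat_session_id": "00000000-0000-0000-0000-000000000000",  # placeholder UUID
        "temperature_override": 0.7,
    },
)
# Values outside [0, 2] (or above 1 for Anthropic models) are rejected with HTTP 400.
resp.raise_for_status()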
@ -42,6 +42,11 @@ class UpdateChatSessionThreadRequest(BaseModel):
    new_alternate_model: str


class UpdateChatSessionTemperatureRequest(BaseModel):
    chat_session_id: UUID
    temperature_override: float


class ChatSessionCreationRequest(BaseModel):
    # If not specified, use Onyx default persona
    persona_id: int = 0
@ -108,6 +113,10 @@ class CreateChatMessageRequest(ChunkContext):
    llm_override: LLMOverride | None = None
    prompt_override: PromptOverride | None = None

    # Allows the caller to override the temperature for the chat session
    # this does persist in the chat thread details
    temperature_override: float | None = None

    # allow user to specify an alternate assistant
    alternate_assistant_id: int | None = None
@ -168,6 +177,7 @@ class ChatSessionDetails(BaseModel):
    shared_status: ChatSessionSharedStatus
    folder_id: int | None = None
    current_alternate_model: str | None = None
    current_temperature_override: float | None = None


class ChatSessionsResponse(BaseModel):
@ -231,6 +241,7 @@ class ChatSessionDetailResponse(BaseModel):
    time_created: datetime
    shared_status: ChatSessionSharedStatus
    current_alternate_model: str | None
    current_temperature_override: float | None


# This one is not used anymore
137 web/package-lock.json (generated)
@ -25,6 +25,7 @@
        "@radix-ui/react-scroll-area": "^1.2.2",
        "@radix-ui/react-select": "^2.1.2",
        "@radix-ui/react-separator": "^1.1.0",
        "@radix-ui/react-slider": "^1.2.2",
        "@radix-ui/react-slot": "^1.1.0",
        "@radix-ui/react-switch": "^1.1.1",
        "@radix-ui/react-tabs": "^1.1.1",
@ -4963,6 +4964,142 @@
        }
      }
    },
    "node_modules/@radix-ui/react-slider": {
      "version": "1.2.2",
      "resolved": "https://registry.npmjs.org/@radix-ui/react-slider/-/react-slider-1.2.2.tgz",
      "integrity": "sha512-sNlU06ii1/ZcbHf8I9En54ZPW0Vil/yPVg4vQMcFNjrIx51jsHbFl1HYHQvCIWJSr1q0ZmA+iIs/ZTv8h7HHSA==",
      "license": "MIT",
      "dependencies": {
        "@radix-ui/number": "1.1.0",
        "@radix-ui/primitive": "1.1.1",
        "@radix-ui/react-collection": "1.1.1",
        "@radix-ui/react-compose-refs": "1.1.1",
        "@radix-ui/react-context": "1.1.1",
        "@radix-ui/react-direction": "1.1.0",
        "@radix-ui/react-primitive": "2.0.1",
        "@radix-ui/react-use-controllable-state": "1.1.0",
        "@radix-ui/react-use-layout-effect": "1.1.0",
        "@radix-ui/react-use-previous": "1.1.0",
        "@radix-ui/react-use-size": "1.1.0"
      },
      "peerDependencies": {
        "@types/react": "*",
        "@types/react-dom": "*",
        "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
        "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
      },
      "peerDependenciesMeta": {
        "@types/react": {
          "optional": true
        },
        "@types/react-dom": {
          "optional": true
        }
      }
    },
    "node_modules/@radix-ui/react-slider/node_modules/@radix-ui/primitive": {
      "version": "1.1.1",
      "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.1.tgz",
      "integrity": "sha512-SJ31y+Q/zAyShtXJc8x83i9TYdbAfHZ++tUZnvjJJqFjzsdUnKsxPL6IEtBlxKkU7yzer//GQtZSV4GbldL3YA==",
      "license": "MIT"
    },
    "node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-collection": {
      "version": "1.1.1",
      "resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.1.tgz",
      "integrity": "sha512-LwT3pSho9Dljg+wY2KN2mrrh6y3qELfftINERIzBUO9e0N+t0oMTyn3k9iv+ZqgrwGkRnLpNJrsMv9BZlt2yuA==",
      "license": "MIT",
      "dependencies": {
        "@radix-ui/react-compose-refs": "1.1.1",
        "@radix-ui/react-context": "1.1.1",
        "@radix-ui/react-primitive": "2.0.1",
        "@radix-ui/react-slot": "1.1.1"
      },
      "peerDependencies": {
        "@types/react": "*",
        "@types/react-dom": "*",
        "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
        "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
      },
      "peerDependenciesMeta": {
        "@types/react": {
          "optional": true
        },
        "@types/react-dom": {
          "optional": true
        }
      }
    },
    "node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-compose-refs": {
      "version": "1.1.1",
      "resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.1.tgz",
      "integrity": "sha512-Y9VzoRDSJtgFMUCoiZBDVo084VQ5hfpXxVE+NgkdNsjiDBByiImMZKKhxMwCbdHvhlENG6a833CbFkOQvTricw==",
      "license": "MIT",
      "peerDependencies": {
        "@types/react": "*",
        "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
      },
      "peerDependenciesMeta": {
        "@types/react": {
          "optional": true
        }
      }
    },
    "node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-context": {
      "version": "1.1.1",
      "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz",
      "integrity": "sha512-UASk9zi+crv9WteK/NU4PLvOoL3OuE6BWVKNF6hPRBtYBDXQ2u5iu3O59zUlJiTVvkyuycnqrztsHVJwcK9K+Q==",
      "license": "MIT",
      "peerDependencies": {
        "@types/react": "*",
        "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
      },
      "peerDependenciesMeta": {
        "@types/react": {
          "optional": true
        }
      }
    },
    "node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-primitive": {
      "version": "2.0.1",
      "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.0.1.tgz",
      "integrity": "sha512-sHCWTtxwNn3L3fH8qAfnF3WbUZycW93SM1j3NFDzXBiz8D6F5UTTy8G1+WFEaiCdvCVRJWj6N2R4Xq6HdiHmDg==",
      "license": "MIT",
      "dependencies": {
        "@radix-ui/react-slot": "1.1.1"
      },
      "peerDependencies": {
        "@types/react": "*",
        "@types/react-dom": "*",
        "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
        "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
      },
      "peerDependenciesMeta": {
        "@types/react": {
          "optional": true
        },
        "@types/react-dom": {
          "optional": true
        }
      }
    },
    "node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-slot": {
      "version": "1.1.1",
      "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.1.1.tgz",
      "integrity": "sha512-RApLLOcINYJA+dMVbOju7MYv1Mb2EBp2nH4HdDzXTSyaR5optlm6Otrz1euW3HbdOR8UmmFK06TD+A9frYWv+g==",
      "license": "MIT",
      "dependencies": {
        "@radix-ui/react-compose-refs": "1.1.1"
      },
      "peerDependencies": {
        "@types/react": "*",
        "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
      },
      "peerDependenciesMeta": {
        "@types/react": {
          "optional": true
        }
      }
    },
    "node_modules/@radix-ui/react-slot": {
      "version": "1.1.0",
      "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.1.0.tgz",
@ -28,6 +28,7 @@
    "@radix-ui/react-scroll-area": "^1.2.2",
    "@radix-ui/react-select": "^2.1.2",
    "@radix-ui/react-separator": "^1.1.0",
    "@radix-ui/react-slider": "^1.2.2",
    "@radix-ui/react-slot": "^1.1.0",
    "@radix-ui/react-switch": "^1.1.1",
    "@radix-ui/react-tabs": "^1.1.1",
@ -404,9 +404,6 @@ export function ChatPage({
      filterManager.setSelectedTags([]);
      filterManager.setTimeRange(null);

      // reset LLM overrides (based on chat session!)
      llmOverrideManager.updateTemperature(null);

      // remove uploaded files
      setCurrentMessageFiles([]);

@ -449,6 +446,7 @@ export function ChatPage({
        );

        const chatSession = (await response.json()) as BackendChatSession;

        setSelectedAssistantFromId(chatSession.persona_id);

        const newMessageMap = processRawChatHistory(chatSession.messages);
@ -1,4 +1,4 @@
import React, { useState } from "react";
import React, { useState, useEffect } from "react";
import {
  Popover,
  PopoverContent,
@ -26,6 +26,9 @@ import {
} from "@/components/ui/tooltip";
import { FiAlertTriangle } from "react-icons/fi";

import { Slider } from "@/components/ui/slider";
import { useUser } from "@/components/user/UserProvider";

interface LLMPopoverProps {
  llmProviders: LLMProviderDescriptor[];
  llmOverrideManager: LlmOverrideManager;
@ -40,6 +43,7 @@ export default function LLMPopover({
  currentAssistant,
}: LLMPopoverProps) {
  const [isOpen, setIsOpen] = useState(false);
  const { user } = useUser();
  const { llmOverride, updateLLMOverride } = llmOverrideManager;
  const currentLlm = llmOverride.modelName;

@ -88,6 +92,22 @@ export default function LLMPopover({
    ? getDisplayNameForModel(defaultModelName)
    : null;

  const [localTemperature, setLocalTemperature] = useState(
    llmOverrideManager.temperature ?? 0.5
  );

  useEffect(() => {
    setLocalTemperature(llmOverrideManager.temperature ?? 0.5);
  }, [llmOverrideManager.temperature]);

  const handleTemperatureChange = (value: number[]) => {
    setLocalTemperature(value[0]);
  };

  const handleTemperatureChangeComplete = (value: number[]) => {
    llmOverrideManager.updateTemperature(value[0]);
  };

  return (
    <Popover open={isOpen} onOpenChange={setIsOpen}>
      <PopoverTrigger asChild>
@ -118,9 +138,9 @@ export default function LLMPopover({
      </PopoverTrigger>
      <PopoverContent
        align="start"
        className="w-64 p-1 bg-background border border-gray-200 rounded-md shadow-lg"
        className="w-64 p-1 bg-background border border-gray-200 rounded-md shadow-lg flex flex-col"
      >
        <div className="max-h-[300px] overflow-y-auto">
        <div className="flex-grow max-h-[300px] default-scrollbar overflow-y-auto">
          {llmOptions.map(({ name, icon, value }, index) => {
            if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
              return (
@ -171,6 +191,25 @@ export default function LLMPopover({
            return null;
          })}
        </div>
        {user?.preferences?.temperature_override_enabled && (
          <div className="mt-2 pt-2 border-t border-gray-200">
            <div className="w-full px-3 py-2">
              <Slider
                value={[localTemperature]}
                max={llmOverrideManager.maxTemperature}
                min={0}
                step={0.01}
                onValueChange={handleTemperatureChange}
                onValueCommit={handleTemperatureChangeComplete}
                className="w-full"
              />
              <div className="flex justify-between text-xs text-gray-500 mt-2">
                <span>Temperature (creativity)</span>
                <span>{localTemperature.toFixed(1)}</span>
              </div>
            </div>
          </div>
        )}
      </PopoverContent>
    </Popover>
  );
@ -68,6 +68,7 @@ export interface ChatSession {
  shared_status: ChatSessionSharedStatus;
  folder_id: number | null;
  current_alternate_model: string;
  current_temperature_override: number | null;
}

export interface SearchSession {
@ -107,6 +108,7 @@ export interface BackendChatSession {
  messages: BackendMessage[];
  time_created: string;
  shared_status: ChatSessionSharedStatus;
  current_temperature_override: number | null;
  current_alternate_model?: string;
}
@ -75,6 +75,23 @@ export async function updateModelOverrideForChatSession(
  return response;
}

export async function updateTemperatureOverrideForChatSession(
  chatSessionId: string,
  newTemperature: number
) {
  const response = await fetch("/api/chat/update-chat-session-temperature", {
    method: "PUT",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify({
      chat_session_id: chatSessionId,
      temperature_override: newTemperature,
    }),
  });
  return response;
}

export async function createChatSession(
  personaId: number,
  description: string | null
@ -30,8 +30,13 @@ export function UserSettingsModal({
  defaultModel: string | null;
}) {
  const { inputPrompts, refreshInputPrompts } = useChatContext();
  const { refreshUser, user, updateUserAutoScroll, updateUserShortcuts } =
    useUser();
  const {
    refreshUser,
    user,
    updateUserAutoScroll,
    updateUserShortcuts,
    updateUserTemperatureOverrideEnabled,
  } = useUser();
  const containerRef = useRef<HTMLDivElement>(null);
  const messageRef = useRef<HTMLDivElement>(null);

@ -179,6 +184,16 @@ export function UserSettingsModal({
            />
            <Label className="text-sm">Enable Prompt Shortcuts</Label>
          </div>
          <div className="flex items-center gap-x-2">
            <Switch
              size="sm"
              checked={user?.preferences?.temperature_override_enabled}
              onCheckedChange={(checked) => {
                updateUserTemperatureOverrideEnabled(checked);
              }}
            />
            <Label className="text-sm">Enable Temperature Override</Label>
          </div>
        </div>

        <Separator />
@ -103,7 +103,7 @@ export function Modal({
            </button>
          </div>
        )}
        <div className="flex-shrink-0">
        <div className="items-start flex-shrink-0">
          {title && (
            <>
              <div className="flex">
28 web/src/components/ui/slider.tsx (normal file)
@ -0,0 +1,28 @@
"use client";

import * as React from "react";
import * as SliderPrimitive from "@radix-ui/react-slider";

import { cn } from "@/lib/utils";

const Slider = React.forwardRef<
  React.ElementRef<typeof SliderPrimitive.Root>,
  React.ComponentPropsWithoutRef<typeof SliderPrimitive.Root>
>(({ className, ...props }, ref) => (
  <SliderPrimitive.Root
    ref={ref}
    className={cn(
      "relative flex w-full touch-none select-none items-center",
      className
    )}
    {...props}
  >
    <SliderPrimitive.Track className="relative h-2 w-full grow overflow-hidden rounded-full bg-neutral-100 dark:bg-neutral-800">
      <SliderPrimitive.Range className="absolute h-full bg-neutral-900 dark:bg-neutral-50" />
    </SliderPrimitive.Track>
    <SliderPrimitive.Thumb className="block h-3 w-3 rounded-full border border-neutral-900 bg-white ring-offset-white transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-neutral-950 focus-visible:ring-offset disabled:pointer-events-none disabled:opacity-50 dark:border-neutral-50 dark:bg-neutral-950 dark:ring-offset-neutral-950 dark:focus-visible:ring-neutral-300" />
  </SliderPrimitive.Root>
));
Slider.displayName = SliderPrimitive.Root.displayName;

export { Slider };
@ -18,6 +18,7 @@ interface UserContextType {
    assistantId: number,
    isPinned: boolean
  ) => Promise<boolean>;
  updateUserTemperatureOverrideEnabled: (enabled: boolean) => Promise<void>;
}

const UserContext = createContext<UserContextType | undefined>(undefined);
@ -57,6 +58,41 @@ export function UserProvider({
      console.error("Error fetching current user:", error);
    }
  };
  const updateUserTemperatureOverrideEnabled = async (enabled: boolean) => {
    try {
      setUpToDateUser((prevUser) => {
        if (prevUser) {
          return {
            ...prevUser,
            preferences: {
              ...prevUser.preferences,
              temperature_override_enabled: enabled,
            },
          };
        }
        return prevUser;
      });

      const response = await fetch(
        `/api/temperature-override-enabled?temperature_override_enabled=${enabled}`,
        {
          method: "PATCH",
          headers: {
            "Content-Type": "application/json",
          },
        }
      );

      if (!response.ok) {
        await refreshUser();
        throw new Error("Failed to update user temperature override setting");
      }
    } catch (error) {
      console.error("Error updating user temperature override setting:", error);
      throw error;
    }
  };

  const updateUserShortcuts = async (enabled: boolean) => {
    try {
      setUpToDateUser((prevUser) => {
@ -184,6 +220,7 @@ export function UserProvider({
        refreshUser,
        updateUserAutoScroll,
        updateUserShortcuts,
        updateUserTemperatureOverrideEnabled,
        toggleAssistantPinnedStatus,
        isAdmin: upToDateUser?.role === UserRole.ADMIN,
        // Curator status applies for either global or basic curator
@ -10,7 +10,7 @@ import {
} from "@/lib/types";
import useSWR, { mutate, useSWRConfig } from "swr";
import { errorHandlingFetcher } from "./fetcher";
import { useContext, useEffect, useState } from "react";
import { useContext, useEffect, useMemo, useState } from "react";
import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector";
import { Filters, SourceMetadata } from "./search/interfaces";
import {
@ -28,6 +28,8 @@ import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
import { getSourceMetadata } from "./sources";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
import { useUser } from "@/components/user/UserProvider";
import { SEARCH_TOOL_ID } from "@/app/chat/tools/constants";
import { updateTemperatureOverrideForChatSession } from "@/app/chat/lib";

const CREDENTIAL_URL = "/api/manage/admin/credential";

@ -360,12 +362,13 @@ export interface LlmOverride {
export interface LlmOverrideManager {
  llmOverride: LlmOverride;
  updateLLMOverride: (newOverride: LlmOverride) => void;
  temperature: number | null;
  updateTemperature: (temperature: number | null) => void;
  temperature: number;
  updateTemperature: (temperature: number) => void;
  updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
  imageFilesPresent: boolean;
  updateImageFilesPresent: (present: boolean) => void;
  liveAssistant: Persona | null;
  maxTemperature: number;
}

// Things to test
@ -395,6 +398,18 @@ Changes take place as
If we have a live assistant, we should use that model override

Relevant test: `llm_ordering.spec.ts`.

Temperature override is set as follows:
- For existing chat sessions:
  - If the user has previously overridden the temperature for a specific chat session,
    that value is persisted and used when the user returns to that chat.
  - This persistence applies even if the temperature was set before sending the first message in the chat.
- For new chat sessions:
  - If the search tool is available, the default temperature is set to 0.
  - If the search tool is not available, the default temperature is set to 0.5.

This approach ensures that user preferences are maintained for existing chats while
providing appropriate defaults for new conversations based on the available tools.
*/

export function useLlmOverride(
@ -407,11 +422,6 @@ export function useLlmOverride(
  const [chatSession, setChatSession] = useState<ChatSession | null>(null);

  const llmOverrideUpdate = () => {
    if (!chatSession && currentChatSession) {
      setChatSession(currentChatSession || null);
      return;
    }

    if (liveAssistant?.llm_model_version_override) {
      setLlmOverride(
        getValidLlmOverride(liveAssistant.llm_model_version_override)
@ -499,24 +509,68 @@ export function useLlmOverride(
    }
  };

  const [temperature, setTemperature] = useState<number | null>(0);

  useEffect(() => {
  const [temperature, setTemperature] = useState<number>(() => {
    llmOverrideUpdate();
  }, [liveAssistant, currentChatSession]);

    if (currentChatSession?.current_temperature_override != null) {
      return Math.min(
        currentChatSession.current_temperature_override,
        isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0
      );
    } else if (
      liveAssistant?.tools.some((tool) => tool.name === SEARCH_TOOL_ID)
    ) {
      return 0;
    }
    return 0.5;
  });

  const maxTemperature = useMemo(() => {
    return isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0;
  }, [llmOverride]);

  useEffect(() => {
    if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
      setTemperature((prevTemp) => Math.min(prevTemp ?? 0, 1.0));
      const newTemperature = Math.min(temperature, 1.0);
      setTemperature(newTemperature);
      if (chatSession?.id) {
        updateTemperatureOverrideForChatSession(chatSession.id, newTemperature);
      }
    }
  }, [llmOverride]);

  const updateTemperature = (temperature: number | null) => {
  useEffect(() => {
    if (!chatSession && currentChatSession) {
      setChatSession(currentChatSession || null);
      if (temperature) {
        updateTemperatureOverrideForChatSession(
          currentChatSession.id,
          temperature
        );
      }
      return;
    }

    if (currentChatSession?.current_temperature_override) {
      setTemperature(currentChatSession.current_temperature_override);
    } else if (
      liveAssistant?.tools.some((tool) => tool.name === SEARCH_TOOL_ID)
    ) {
      setTemperature(0);
    } else {
      setTemperature(0.5);
    }
  }, [liveAssistant, currentChatSession]);

  const updateTemperature = (temperature: number) => {
    if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
      setTemperature((prevTemp) => Math.min(temperature ?? 0, 1.0));
      setTemperature((prevTemp) => Math.min(temperature, 1.0));
    } else {
      setTemperature(temperature);
    }
    if (chatSession) {
      updateTemperatureOverrideForChatSession(chatSession.id, temperature);
    }
  };

  return {
@ -528,6 +582,7 @@ export function useLlmOverride(
    imageFilesPresent,
    updateImageFilesPresent,
    liveAssistant: liveAssistant ?? null,
    maxTemperature,
  };
}
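The default-and-clamping policy described in the comment block of this hunk can be restated as a small sketch (Python here, purely illustrative; the actual logic lives in useLlmOverride and the function below is not part of the codebase):

def effective_temperature(
    session_override: float | None,
    has_search_tool: bool,
    is_anthropic_model: bool,
) -> float:
    """Illustrative restatement of the temperature selection in useLlmOverride."""
    cap = 1.0 if is_anthropic_model else 2.0  # Anthropic models are capped at 1.0
    if session_override is not None:
        # An existing per-session override is persisted and reused, clamped to the cap.
        return min(session_override, cap)
    # New sessions default to 0 when the search tool is available, otherwise 0.5.
    return 0.0 if has_search_tool else 0.5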
@ -12,6 +12,7 @@ interface UserPreferences {
  recent_assistants: number[];
  auto_scroll: boolean | null;
  shortcut_enabled: boolean;
  temperature_override_enabled: boolean;
}

export enum UserRole {