Add user specific model defaults (#2043)

This commit is contained in:
pablodanswer
2024-08-10 14:37:33 -07:00
committed by GitHub
parent a4caf66a35
commit 7d5cfd2fa3
15 changed files with 342 additions and 33 deletions

View File

@@ -0,0 +1,24 @@
"""Added model defaults for users
Revision ID: 7477a5f5d728
Revises: 213fd978c6d8
Create Date: 2024-08-04 19:00:04.512634
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7477a5f5d728"
down_revision = "213fd978c6d8"
branch_labels = None
depends_on = None
def upgrade() -> None:
    # Add a nullable free-form text column for the user's preferred default
    # LLM; nullable so existing rows need no backfill.
    op.add_column("user", sa.Column("default_model", sa.Text(), nullable=True))
def downgrade() -> None:
    # Revert the migration by dropping the column added in upgrade().
    op.drop_column("user", "default_model")

View File

@@ -23,7 +23,7 @@ def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
)
return UserPreferences(**preferences_data)
except ConfigNotFoundError:
return UserPreferences(chosen_assistants=None)
return UserPreferences(chosen_assistants=None, default_model=None)
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:

View File

@@ -125,6 +125,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
TIMESTAMPAware(timezone=True), nullable=True
)
default_model: Mapped[str] = mapped_column(Text, nullable=True)
# organized in typical structured fashion
# formatted as `displayName__provider__modelName`
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user", lazy="joined"

View File

@@ -39,6 +39,7 @@ class AuthTypeResponse(BaseModel):
class UserPreferences(BaseModel):
chosen_assistants: list[int] | None
default_model: str | None
class UserInfo(BaseModel):
@@ -67,7 +68,12 @@ class UserInfo(BaseModel):
is_superuser=user.is_superuser,
is_verified=user.is_verified,
role=user.role,
preferences=(UserPreferences(chosen_assistants=user.chosen_assistants)),
preferences=(
UserPreferences(
chosen_assistants=user.chosen_assistants,
default_model=user.default_model,
)
),
oidc_expiry=user.oidc_expiry,
current_token_created_at=current_token_created_at,
current_token_expiry_length=expiry_length,

View File

@@ -316,6 +316,34 @@ def verify_user_logged_in(
"""APIs to adjust user preferences"""
class ChosenDefaultModelRequest(BaseModel):
    # Payload for PATCH /user/default-model. Formatted as
    # `displayName__provider__modelName` (matches User.default_model);
    # None clears the user-specific preference.
    default_model: str | None
@router.patch("/user/default-model")
def update_user_default_model(
    request: ChosenDefaultModelRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Persist the caller's preferred default LLM model.

    For authenticated users the value is written to the `user` table; when
    auth is disabled, the pseudo "no-auth user" preferences are stored in the
    dynamic config store instead.
    """
    if user is None:
        # No user object should only occur when auth is disabled entirely;
        # anything else indicates a broken dependency chain.
        if AUTH_TYPE != AuthType.DISABLED:
            raise RuntimeError("This should never happen")

        store = get_dynamic_config_store()
        no_auth_user = fetch_no_auth_user(store)
        no_auth_user.preferences.default_model = request.default_model
        set_no_auth_user_preferences(store, no_auth_user.preferences)
        return

    db_session.execute(
        update(User)
        .where(User.id == user.id)  # type: ignore
        .values(default_model=request.default_model)
    )
    db_session.commit()
class ChosenAssistantsRequest(BaseModel):
chosen_assistants: list[int]

View File

@@ -36,7 +36,8 @@ class UserGroup(BaseModel):
is_verified=user.is_verified,
role=user.role,
preferences=UserPreferences(
chosen_assistants=user.chosen_assistants
default_model=user.default_model,
chosen_assistants=user.chosen_assistants,
),
)
for user in user_group_model.users

View File

@@ -28,7 +28,7 @@ import { GroupsIcon, PaintingIcon, SwapIcon } from "@/components/icons/icons";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
import { useUserGroups } from "@/lib/hooks";
import { checkLLMSupportsImageInput } from "@/lib/llm/utils";
import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils";
import { ToolSnapshot } from "@/lib/tools/interfaces";
import { checkUserIsNoAuthUser } from "@/lib/user";
import {
@@ -539,13 +539,15 @@ export function AssistantEditor({
</TooltipProvider>
</div>
<p className="my-1 text-text-600">
You assistant will use your system default (currently{" "}
{defaultModelName}) unless otherwise specified below.
Your assistant will use the user&apos;s set default unless
otherwise specified below.
{user?.preferences.default_model &&
` Your current (user-specific) default model is ${getDisplayNameForModel(destructureValue(user?.preferences?.default_model!).modelName)}`}
</p>
<div className="mb-2 flex items-starts">
<div className="w-96">
<SelectorFormField
defaultValue={`Default (${defaultModelName})`}
defaultValue={`User default`}
name="llm_model_provider_override"
options={llmProviders.map((llmProvider) => ({
name: llmProvider.name,

View File

@@ -63,8 +63,10 @@ import Dropzone from "react-dropzone";
import {
checkLLMSupportsImageInput,
getFinalLLM,
destructureValue,
getLLMProviderOverrideForPersona,
} from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar";
import { useChatContext } from "@/components/context/ChatContext";
import { v4 as uuidv4 } from "uuid";
@@ -76,6 +78,7 @@ import { useSidebarVisibility } from "@/components/chat_search/hooks";
import { SIDEBAR_TOGGLED_COOKIE_NAME } from "@/components/resizable/constants";
import FixedLogo from "./shared_chat_search/FixedLogo";
import { getSecondsUntilExpiration } from "@/lib/time";
import { SetDefaultModelModal } from "./modal/SetDefaultModelModal";
const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@@ -165,14 +168,16 @@ export function ChatPage({
availableAssistants.find((assistant) => assistant.id === assistantId)
);
};
const liveAssistant =
selectedAssistant || filteredAssistants[0] || availableAssistants[0];
const llmOverrideManager = useLlmOverride(
user?.preferences.default_model,
selectedChatSession,
defaultTemperature
);
const liveAssistant =
selectedAssistant || filteredAssistants[0] || availableAssistants[0];
useEffect(() => {
const personaDefault = getLLMProviderOverrideForPersona(
liveAssistant,
@@ -181,6 +186,10 @@ export function ChatPage({
if (personaDefault) {
llmOverrideManager.setLlmOverride(personaDefault);
} else if (user?.preferences.default_model) {
llmOverrideManager.setLlmOverride(
destructureValue(user?.preferences.default_model)
);
}
}, [liveAssistant]);
@@ -785,7 +794,9 @@ export function ChatPage({
const currentAssistantId = alternativeAssistantOverride
? alternativeAssistantOverride.id
: (alternativeAssistant?.id ?? liveAssistant.id);
: alternativeAssistant
? alternativeAssistant.id
: liveAssistant.id;
resetInputBar();
@@ -831,10 +842,14 @@ export function ChatPage({
queryOverride,
forceSearch,
modelProvider: llmOverrideManager.llmOverride.name || undefined,
modelProvider:
llmOverrideManager.llmOverride.name ||
llmOverrideManager.globalDefault.name ||
undefined,
modelVersion:
llmOverrideManager.llmOverride.modelName ||
searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
llmOverrideManager.globalDefault.modelName ||
undefined,
temperature: llmOverrideManager.temperature || undefined,
systemPromptOverride:
@@ -958,6 +973,7 @@ export function ChatPage({
completeMessageMapOverride: frozenMessageMap,
});
}
setIsStreaming(false);
if (isNewSession) {
if (finalMessage) {
@@ -1151,6 +1167,7 @@ export function ChatPage({
});
const innerSidebarElementRef = useRef<HTMLDivElement>(null);
const [settingsToggled, setSettingsToggled] = useState(false);
const currentPersona = alternativeAssistant || liveAssistant;
@@ -1185,6 +1202,8 @@ export function ChatPage({
<InstantSSRAutoRefresh />
{/* ChatPopup is a custom popup that displays a admin-specified message on initial user visit.
Only used in the EE version of the app. */}
{popup}
<ChatPopup />
{currentFeedback && (
<FeedbackModal
@@ -1201,6 +1220,15 @@ export function ChatPage({
}}
/>
)}
{settingsToggled && (
<SetDefaultModelModal
setLlmOverride={llmOverrideManager.setGlobalDefault}
defaultModel={user?.preferences.default_model!}
llmProviders={llmProviders}
onClose={() => setSettingsToggled(false)}
/>
)}
{sharingModalVisible && chatSessionIdRef.current !== null && (
<ShareChatSessionModal
chatSessionId={chatSessionIdRef.current}
@@ -1258,8 +1286,6 @@ export function ChatPage({
ref={masterFlexboxRef}
className="flex h-full w-full overflow-x-hidden"
>
{popup}
<div className="flex h-full flex-col w-full">
{liveAssistant && (
<FunctionalHeader
@@ -1641,6 +1667,7 @@ export function ChatPage({
)}
<ChatInputBar
openModelSettings={() => setSettingsToggled(true)}
inputPrompts={userInputPrompts}
showDocs={() => setDocumentSelection(true)}
selectedDocuments={selectedDocuments}

View File

@@ -34,6 +34,7 @@ import { SettingsContext } from "@/components/settings/SettingsProvider";
const MAX_INPUT_HEIGHT = 200;
export function ChatInputBar({
openModelSettings,
showDocs,
selectedDocuments,
message,
@@ -57,6 +58,7 @@ export function ChatInputBar({
chatSessionId,
inputPrompts,
}: {
openModelSettings: () => void;
showDocs: () => void;
selectedDocuments: DanswerDocument[];
assistantOptions: Persona[];
@@ -507,7 +509,6 @@ export function ChatInputBar({
}}
suppressContentEditableWarning={true}
/>
<div className="flex items-center space-x-3 mr-12 px-4 pb-2 ">
<Popup
removePadding
@@ -534,15 +535,16 @@ export function ChatInputBar({
Icon={AssistantsIconSkeleton as IconType}
/>
</Popup>
<Popup
tab
content={(close, ref) => (
<LlmTab
openModelSettings={openModelSettings}
currentLlm={
llmOverrideManager.llmOverride.modelName ||
(selectedAssistant
? selectedAssistant.llm_model_version_override ||
llmOverrideManager.globalDefault.modelName ||
llmName
: llmName)
}
@@ -564,6 +566,7 @@ export function ChatInputBar({
llmOverrideManager.llmOverride.modelName ||
(selectedAssistant
? selectedAssistant.llm_model_version_override ||
llmOverrideManager.globalDefault.modelName ||
llmName
: llmName)
)

View File

@@ -0,0 +1,174 @@
import { Dispatch, SetStateAction, useState } from "react";
import { ModalWrapper } from "./ModalWrapper";
import { Badge, Text } from "@tremor/react";
import {
getDisplayNameForModel,
LlmOverride,
LlmOverrideManager,
useLlmOverride,
} from "@/lib/hooks";
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
import { destructureValue, structureValue } from "@/lib/llm/utils";
import { setUserDefaultModel } from "@/lib/users/UserSettings";
import { useRouter } from "next/navigation";
import { usePopup } from "@/components/admin/connectors/Popup";
export function SetDefaultModelModal({
  llmProviders,
  onClose,
  setLlmOverride,
  defaultModel,
}: {
  llmProviders: LLMProviderDescriptor[];
  setLlmOverride: Dispatch<SetStateAction<LlmOverride>>;
  onClose: () => void;
  defaultModel: string | null;
}) {
  const { popup, setPopup } = usePopup();
  const router = useRouter();

  // `defaultModel` is stored as `displayName__provider__modelName`; split it
  // so we can compare against individual model names below.
  const defaultModelDestructured = defaultModel
    ? destructureValue(defaultModel)
    : null;

  // Build selectable options, de-duplicating model names that appear under
  // multiple providers. `display_model_names` (when set) is the admin-curated
  // subset; otherwise fall back to the provider's full model list.
  const llmOptionsByProvider: {
    [provider: string]: { name: string; value: string }[];
  } = {};
  const uniqueModelNames = new Set<string>();

  llmProviders.forEach((llmProvider) => {
    if (!llmOptionsByProvider[llmProvider.provider]) {
      llmOptionsByProvider[llmProvider.provider] = [];
    }

    (llmProvider.display_model_names || llmProvider.model_names).forEach(
      (modelName) => {
        if (!uniqueModelNames.has(modelName)) {
          uniqueModelNames.add(modelName);
          llmOptionsByProvider[llmProvider.provider].push({
            name: modelName,
            value: structureValue(
              llmProvider.name,
              llmProvider.provider,
              modelName
            ),
          });
        }
      }
    );
  });

  const llmOptions = Object.values(llmOptionsByProvider).flat();

  // Persist the choice server-side; on success mirror it into the local
  // override state so the open chat session picks it up without a reload.
  const handleChangeDefaultModel = async (defaultModel: string | null) => {
    try {
      const response = await setUserDefaultModel(defaultModel);

      if (response.ok) {
        if (defaultModel) {
          setLlmOverride(destructureValue(defaultModel));
        }
        setPopup({
          message: "Default model updated successfully",
          type: "success",
        });
        router.refresh();
      } else {
        throw new Error("Failed to update default model");
      }
    } catch (error) {
      setPopup({
        message: "Failed to update default model",
        type: "error",
      });
    }
  };

  return (
    <ModalWrapper
      onClose={onClose}
      modalClassName="rounded-lg bg-white max-w-xl"
    >
      <>
        {popup}
        <div className="flex mb-4">
          <h2 className="text-2xl text-emphasis font-bold flex my-auto">
            Set Default Model
          </h2>
        </div>

        <Text className="mb-4">
          Choose a Large Language Model (LLM) to serve as the default for
          assistants that don&apos;t have a default model assigned.
          {defaultModel == null && " No default model has been selected!"}
        </Text>

        <div className="w-full flex text-sm flex-col">
          {/* "System default" row — clears the user-specific default.
              NOTE: plain <div>s here; <td> is invalid outside a <table>. */}
          <div className="w-full flex border-b hover:bg-background-50">
            <div className="min-w-[80px]">
              {defaultModel == null ? (
                <Badge>selected</Badge>
              ) : (
                <input
                  type="radio"
                  name="credentialSelection"
                  onChange={(e) => {
                    e.preventDefault();
                    handleChangeDefaultModel(null);
                  }}
                  className="form-radio ml-4 h-4 w-4 text-blue-600 transition duration-150 ease-in-out"
                />
              )}
            </div>
            <div className="p-2">System default</div>
          </div>

          {llmOptions.map(({ name, value }, index) => {
            const isSelected = defaultModelDestructured?.modelName === name;
            return (
              <div
                key={index}
                className="w-full flex border-b hover:bg-background-50"
              >
                <div className="min-w-[80px]">
                  {!isSelected ? (
                    <input
                      type="radio"
                      name="credentialSelection"
                      onChange={(e) => {
                        e.preventDefault();
                        handleChangeDefaultModel(value);
                      }}
                      className="form-radio ml-4 h-4 w-4 text-blue-600 transition duration-150 ease-in-out"
                    />
                  ) : (
                    <Badge>selected</Badge>
                  )}
                </div>
                <div className="p-2">
                  {getDisplayNameForModel(name)}{" "}
                  {isSelected && "(selected)"}
                </div>
              </div>
            );
          })}
        </div>
      </>
    </ModalWrapper>
  );
}

View File

@@ -8,20 +8,28 @@ import { Persona } from "@/app/admin/assistants/interfaces";
import { destructureValue, getFinalLLM, structureValue } from "@/lib/llm/utils";
import { updateModelOverrideForChatSession } from "../../lib";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { InfoIcon } from "@/components/icons/icons";
import { GearIcon, InfoIcon } from "@/components/icons/icons";
import { CustomTooltip } from "@/components/tooltip/CustomTooltip";
interface LlmTabProps {
llmOverrideManager: LlmOverrideManager;
currentAssistant: Persona;
currentLlm: string;
openModelSettings: () => void;
chatSessionId?: number;
close: () => void;
}
export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
(
{ llmOverrideManager, currentAssistant, chatSessionId, currentLlm, close },
{
llmOverrideManager,
currentAssistant,
chatSessionId,
currentLlm,
close,
openModelSettings,
},
ref
) => {
const { llmProviders } = useChatContext();
@@ -43,12 +51,6 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
debouncedSetTemperature(value);
};
const [_, defaultLlmName] = getFinalLLM(
llmProviders,
currentAssistant,
null
);
const llmOptionsByProvider: {
[provider: string]: { name: string; value: string }[];
} = {};
@@ -81,8 +83,16 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
);
return (
<div className="w-full">
<div className="flex w-full content-center gap-x-2">
<label className="block text-sm font-medium mb-2">Choose Model</label>
<div className="flex w-full justify-between content-center mb-2 gap-x-2">
<label className="block text-sm font-medium ">Choose Model</label>
<button
onClick={() => {
close();
openModelSettings();
}}
>
<GearIcon />
</button>
</div>
<div className="max-h-[300px] flex flex-col gap-y-1 overflow-y-scroll">
{llmOptions.map(({ name, value }, index) => {

View File

@@ -9,7 +9,11 @@ import { checkUserIsNoAuthUser, logout } from "@/lib/user";
import { Popover } from "./popover/Popover";
import { LOGOUT_DISABLED } from "@/lib/constants";
import { SettingsContext } from "./settings/SettingsProvider";
import { LightSettingsIcon } from "./icons/icons";
import {
AssistantsIconSkeleton,
LightSettingsIcon,
UsersIcon,
} from "./icons/icons";
import { pageType } from "@/app/chat/sessionSidebar/types";
export function UserDropdown({
@@ -105,6 +109,7 @@ export function UserDropdown({
</Link>
</>
)}
{showLogout && (
<>
{(!(page == "search" || page == "chat") || showAdminPanel) && (

View File

@@ -137,15 +137,27 @@ export interface LlmOverride {
export interface LlmOverrideManager {
llmOverride: LlmOverride;
setLlmOverride: React.Dispatch<React.SetStateAction<LlmOverride>>;
globalDefault: LlmOverride;
setGlobalDefault: React.Dispatch<React.SetStateAction<LlmOverride>>;
temperature: number | null;
setTemperature: React.Dispatch<React.SetStateAction<number | null>>;
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
}
export function useLlmOverride(
globalModel?: string | null,
currentChatSession?: ChatSession,
defaultTemperature?: number
): LlmOverrideManager {
const [globalDefault, setGlobalDefault] = useState<LlmOverride>(
globalModel
? destructureValue(globalModel)
: {
name: "",
provider: "",
modelName: "",
}
);
const [llmOverride, setLlmOverride] = useState<LlmOverride>(
currentChatSession && currentChatSession.current_alternate_model
? destructureValue(currentChatSession.current_alternate_model)
@@ -160,11 +172,7 @@ export function useLlmOverride(
setLlmOverride(
chatSession && chatSession.current_alternate_model
? destructureValue(chatSession.current_alternate_model)
: {
name: "",
provider: "",
modelName: "",
}
: globalDefault
);
};
@@ -180,11 +188,12 @@ export function useLlmOverride(
updateModelOverrideForChatSession,
llmOverride,
setLlmOverride,
globalDefault,
setGlobalDefault,
temperature,
setTemperature,
};
}
/*
EE Only APIs
*/

View File

@@ -4,6 +4,7 @@ import { Connector } from "./connectors/connectors";
export interface UserPreferences {
chosen_assistants: number[] | null;
default_model: string | null;
}
export enum UserStatus {

View File

@@ -0,0 +1,15 @@
import { LlmOverride } from "../hooks";
/**
 * Persist the user's default LLM selection on the server.
 *
 * Sends a PATCH to `/api/user/default-model`; passing `null` clears the
 * user-specific default. The raw Response is returned so callers can
 * inspect `response.ok` themselves.
 */
export async function setUserDefaultModel(
  model: string | null
): Promise<Response> {
  return fetch(`/api/user/default-model`, {
    method: "PATCH",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ default_model: model }),
  });
}