From 7d5cfd2fa30e8e529272ff55bd1f93c1da212a53 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Sat, 10 Aug 2024 14:37:33 -0700 Subject: [PATCH] Add user specific model defaults (#2043) --- ...a5f5d728_added_model_defaults_for_users.py | 24 +++ backend/danswer/auth/noauth_user.py | 2 +- backend/danswer/db/models.py | 4 + backend/danswer/server/manage/models.py | 8 +- backend/danswer/server/manage/users.py | 28 +++ .../ee/danswer/server/user_group/models.py | 3 +- .../app/admin/assistants/AssistantEditor.tsx | 10 +- web/src/app/chat/ChatPage.tsx | 39 +++- web/src/app/chat/input/ChatInputBar.tsx | 7 +- .../app/chat/modal/SetDefaultModelModal.tsx | 174 ++++++++++++++++++ .../app/chat/modal/configuration/LlmTab.tsx | 30 ++- web/src/components/UserDropdown.tsx | 7 +- web/src/lib/hooks.ts | 23 ++- web/src/lib/types.ts | 1 + web/src/lib/users/UserSettings.tsx | 15 ++ 15 files changed, 342 insertions(+), 33 deletions(-) create mode 100644 backend/alembic/versions/7477a5f5d728_added_model_defaults_for_users.py create mode 100644 web/src/app/chat/modal/SetDefaultModelModal.tsx create mode 100644 web/src/lib/users/UserSettings.tsx diff --git a/backend/alembic/versions/7477a5f5d728_added_model_defaults_for_users.py b/backend/alembic/versions/7477a5f5d728_added_model_defaults_for_users.py new file mode 100644 index 000000000000..b15637020d97 --- /dev/null +++ b/backend/alembic/versions/7477a5f5d728_added_model_defaults_for_users.py @@ -0,0 +1,24 @@ +"""Added model defaults for users + +Revision ID: 7477a5f5d728 +Revises: 213fd978c6d8 +Create Date: 2024-08-04 19:00:04.512634 + +""" + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "7477a5f5d728" +down_revision = "213fd978c6d8" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column("user", sa.Column("default_model", sa.Text(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("user", "default_model") diff --git a/backend/danswer/auth/noauth_user.py b/backend/danswer/auth/noauth_user.py index 55fdbe4a5569..11b04bf40e96 100644 --- a/backend/danswer/auth/noauth_user.py +++ b/backend/danswer/auth/noauth_user.py @@ -23,7 +23,7 @@ def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences: ) return UserPreferences(**preferences_data) except ConfigNotFoundError: - return UserPreferences(chosen_assistants=None) + return UserPreferences(chosen_assistants=None, default_model=None) def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo: diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 0fa09c5344be..dc0810e76ceb 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -125,6 +125,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base): TIMESTAMPAware(timezone=True), nullable=True ) + default_model: Mapped[str] = mapped_column(Text, nullable=True) + # organized in typical structured fashion + # formatted as `displayName__provider__modelName` + # relationships credentials: Mapped[list["Credential"]] = relationship( "Credential", back_populates="user", lazy="joined" diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index ee8a6ed6ed42..bf39813ac5bb 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -39,6 +39,7 @@ class AuthTypeResponse(BaseModel): class UserPreferences(BaseModel): chosen_assistants: list[int] | None + default_model: str | None class UserInfo(BaseModel): @@ -67,7 +68,12 @@ class UserInfo(BaseModel): is_superuser=user.is_superuser, is_verified=user.is_verified, role=user.role, - preferences=(UserPreferences(chosen_assistants=user.chosen_assistants)), + preferences=( + UserPreferences( + chosen_assistants=user.chosen_assistants, + default_model=user.default_model, + ) + ), oidc_expiry=user.oidc_expiry, current_token_created_at=current_token_created_at, current_token_expiry_length=expiry_length, diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 0ba92ea2f4f2..c50df365dd8f 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -316,6 +316,34 @@ def verify_user_logged_in( """APIs to adjust user preferences""" +class ChosenDefaultModelRequest(BaseModel): + default_model: str | None + + +@router.patch("/user/default-model") +def update_user_default_model( + request: ChosenDefaultModelRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + if user is None: + if AUTH_TYPE == AuthType.DISABLED: + store = get_dynamic_config_store() + no_auth_user = fetch_no_auth_user(store) + no_auth_user.preferences.default_model = request.default_model + set_no_auth_user_preferences(store, no_auth_user.preferences) + return + else: + raise RuntimeError("This should never happen") + + db_session.execute( + update(User) + .where(User.id == user.id) # type: ignore + .values(default_model=request.default_model) + ) + db_session.commit() + + class ChosenAssistantsRequest(BaseModel): chosen_assistants: list[int] diff --git a/backend/ee/danswer/server/user_group/models.py b/backend/ee/danswer/server/user_group/models.py index fa1359469cb7..22a6d55f511a 100644 --- a/backend/ee/danswer/server/user_group/models.py +++ b/backend/ee/danswer/server/user_group/models.py @@ -36,7 +36,8 @@ class UserGroup(BaseModel): is_verified=user.is_verified, role=user.role, preferences=UserPreferences( - chosen_assistants=user.chosen_assistants + default_model=user.default_model, + chosen_assistants=user.chosen_assistants, ), ) for user in user_group_model.users diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index 7f6d99a1513d..00b2ba13f7a6 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -28,7 +28,7 @@ import { GroupsIcon, PaintingIcon, SwapIcon } from "@/components/icons/icons"; import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences"; import { useUserGroups } from "@/lib/hooks"; -import { checkLLMSupportsImageInput } from "@/lib/llm/utils"; +import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils"; import { ToolSnapshot } from "@/lib/tools/interfaces"; import { checkUserIsNoAuthUser } from "@/lib/user"; import { @@ -539,13 +539,15 @@ export function AssistantEditor({

- You assistant will use your system default (currently{" "} - {defaultModelName}) unless otherwise specified below. + Your assistant will use the user's set default unless + otherwise specified below. + {user?.preferences.default_model && + ` Your current (user-specific) default model is ${getDisplayNameForModel(destructureValue(user?.preferences?.default_model!).modelName)}`}

({ name: llmProvider.name, diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index b1f79f9d0d7e..925454786c56 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -63,8 +63,10 @@ import Dropzone from "react-dropzone"; import { checkLLMSupportsImageInput, getFinalLLM, + destructureValue, getLLMProviderOverrideForPersona, } from "@/lib/llm/utils"; + import { ChatInputBar } from "./input/ChatInputBar"; import { useChatContext } from "@/components/context/ChatContext"; import { v4 as uuidv4 } from "uuid"; @@ -76,6 +78,7 @@ import { useSidebarVisibility } from "@/components/chat_search/hooks"; import { SIDEBAR_TOGGLED_COOKIE_NAME } from "@/components/resizable/constants"; import FixedLogo from "./shared_chat_search/FixedLogo"; import { getSecondsUntilExpiration } from "@/lib/time"; +import { SetDefaultModelModal } from "./modal/SetDefaultModelModal"; const TEMP_USER_MESSAGE_ID = -1; const TEMP_ASSISTANT_MESSAGE_ID = -2; @@ -165,14 +168,16 @@ export function ChatPage({ availableAssistants.find((assistant) => assistant.id === assistantId) ); }; - const liveAssistant = - selectedAssistant || filteredAssistants[0] || availableAssistants[0]; const llmOverrideManager = useLlmOverride( + user?.preferences.default_model, selectedChatSession, defaultTemperature ); + const liveAssistant = + selectedAssistant || filteredAssistants[0] || availableAssistants[0]; + useEffect(() => { const personaDefault = getLLMProviderOverrideForPersona( liveAssistant, @@ -181,6 +186,10 @@ export function ChatPage({ if (personaDefault) { llmOverrideManager.setLlmOverride(personaDefault); + } else if (user?.preferences.default_model) { + llmOverrideManager.setLlmOverride( + destructureValue(user?.preferences.default_model) + ); } }, [liveAssistant]); @@ -785,7 +794,9 @@ export function ChatPage({ const currentAssistantId = alternativeAssistantOverride ? alternativeAssistantOverride.id - : (alternativeAssistant?.id ?? liveAssistant.id); + : alternativeAssistant + ? alternativeAssistant.id + : liveAssistant.id; resetInputBar(); @@ -831,10 +842,14 @@ export function ChatPage({ queryOverride, forceSearch, - modelProvider: llmOverrideManager.llmOverride.name || undefined, + modelProvider: + llmOverrideManager.llmOverride.name || + llmOverrideManager.globalDefault.name || + undefined, modelVersion: llmOverrideManager.llmOverride.modelName || searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) || + llmOverrideManager.globalDefault.modelName || undefined, temperature: llmOverrideManager.temperature || undefined, systemPromptOverride: @@ -958,6 +973,7 @@ export function ChatPage({ completeMessageMapOverride: frozenMessageMap, }); } + setIsStreaming(false); if (isNewSession) { if (finalMessage) { @@ -1151,6 +1167,7 @@ export function ChatPage({ }); const innerSidebarElementRef = useRef(null); + const [settingsToggled, setSettingsToggled] = useState(false); const currentPersona = alternativeAssistant || liveAssistant; @@ -1185,6 +1202,8 @@ export function ChatPage({ {/* ChatPopup is a custom popup that displays a admin-specified message on initial user visit. Only used in the EE version of the app. */} + {popup} + {currentFeedback && ( )} + + {settingsToggled && ( + setSettingsToggled(false)} + /> + )} {sharingModalVisible && chatSessionIdRef.current !== null && ( - {popup} -
{liveAssistant && ( setSettingsToggled(true)} inputPrompts={userInputPrompts} showDocs={() => setDocumentSelection(true)} selectedDocuments={selectedDocuments} diff --git a/web/src/app/chat/input/ChatInputBar.tsx b/web/src/app/chat/input/ChatInputBar.tsx index 5067d1ae6d77..9c1b64aba233 100644 --- a/web/src/app/chat/input/ChatInputBar.tsx +++ b/web/src/app/chat/input/ChatInputBar.tsx @@ -34,6 +34,7 @@ import { SettingsContext } from "@/components/settings/SettingsProvider"; const MAX_INPUT_HEIGHT = 200; export function ChatInputBar({ + openModelSettings, showDocs, selectedDocuments, message, @@ -57,6 +58,7 @@ export function ChatInputBar({ chatSessionId, inputPrompts, }: { + openModelSettings: () => void; showDocs: () => void; selectedDocuments: DanswerDocument[]; assistantOptions: Persona[]; @@ -507,7 +509,6 @@ export function ChatInputBar({ }} suppressContentEditableWarning={true} /> -
- ( >; + onClose: () => void; + defaultModel: string | null; +}) { + const { popup, setPopup } = usePopup(); + + const defaultModelDestructured = defaultModel + ? destructureValue(defaultModel) + : null; + const modelOptionsByProvider = new Map< + string, + { name: string; value: string }[] + >(); + llmProviders.forEach((llmProvider) => { + const providerOptions = llmProvider.model_names.map( + (modelName: string) => ({ + name: getDisplayNameForModel(modelName), + value: modelName, + }) + ); + modelOptionsByProvider.set(llmProvider.name, providerOptions); + }); + + const llmOptionsByProvider: { + [provider: string]: { name: string; value: string }[]; + } = {}; + const uniqueModelNames = new Set(); + + llmProviders.forEach((llmProvider) => { + if (!llmOptionsByProvider[llmProvider.provider]) { + llmOptionsByProvider[llmProvider.provider] = []; + } + + (llmProvider.display_model_names || llmProvider.model_names).forEach( + (modelName) => { + if (!uniqueModelNames.has(modelName)) { + uniqueModelNames.add(modelName); + llmOptionsByProvider[llmProvider.provider].push({ + name: modelName, + value: structureValue( + llmProvider.name, + llmProvider.provider, + modelName + ), + }); + } + } + ); + }); + + const llmOptions = Object.entries(llmOptionsByProvider).flatMap( + ([provider, options]) => [...options] + ); + + const router = useRouter(); + const handleChangedefaultModel = async (defaultModel: string | null) => { + try { + const response = await setUserDefaultModel(defaultModel); + + if (response.ok) { + if (defaultModel) { + setLlmOverride(destructureValue(defaultModel)); + } + setPopup({ + message: "Default model updated successfully", + type: "success", + }); + router.refresh(); + } else { + throw new Error("Failed to update default model"); + } + } catch (error) { + setPopup({ + message: "Failed to update default model", + type: "error", + }); + } + }; + + return ( + + <> + {popup} +
+

+ Set Default Model +

+
+ + + Choose a Large Language Model (LLM) to serve as the default for + assistants that don't have a default model assigned. + {defaultModel == null && " No default model has been selected!"} + +
+
+ + {defaultModel == null ? ( + selected + ) : ( + { + e.preventDefault(); + handleChangedefaultModel(null); + }} + className="form-radio ml-4 h-4 w-4 text-blue-600 transition duration-150 ease-in-out" + /> + )} + + System default +
+ + {llmOptions.map(({ name, value }, index) => { + return ( +
+ + {defaultModelDestructured?.modelName != name ? ( + { + e.preventDefault(); + handleChangedefaultModel(value); + }} + className="form-radio ml-4 h-4 w-4 text-blue-600 transition duration-150 ease-in-out" + /> + ) : ( + selected + )} + + + {getDisplayNameForModel(name)}{" "} + {defaultModelDestructured && + defaultModelDestructured.name == name && + "(selected)"} + +
+ ); + })} +
+ +
+ ); +} diff --git a/web/src/app/chat/modal/configuration/LlmTab.tsx b/web/src/app/chat/modal/configuration/LlmTab.tsx index b7538ee28fd4..27cd8cb0958e 100644 --- a/web/src/app/chat/modal/configuration/LlmTab.tsx +++ b/web/src/app/chat/modal/configuration/LlmTab.tsx @@ -8,20 +8,28 @@ import { Persona } from "@/app/admin/assistants/interfaces"; import { destructureValue, getFinalLLM, structureValue } from "@/lib/llm/utils"; import { updateModelOverrideForChatSession } from "../../lib"; import { Tooltip } from "@/components/tooltip/Tooltip"; -import { InfoIcon } from "@/components/icons/icons"; +import { GearIcon, InfoIcon } from "@/components/icons/icons"; import { CustomTooltip } from "@/components/tooltip/CustomTooltip"; interface LlmTabProps { llmOverrideManager: LlmOverrideManager; currentAssistant: Persona; currentLlm: string; + openModelSettings: () => void; chatSessionId?: number; close: () => void; } export const LlmTab = forwardRef( ( - { llmOverrideManager, currentAssistant, chatSessionId, currentLlm, close }, + { + llmOverrideManager, + currentAssistant, + chatSessionId, + currentLlm, + close, + openModelSettings, + }, ref ) => { const { llmProviders } = useChatContext(); @@ -43,12 +51,6 @@ export const LlmTab = forwardRef( debouncedSetTemperature(value); }; - const [_, defaultLlmName] = getFinalLLM( - llmProviders, - currentAssistant, - null - ); - const llmOptionsByProvider: { [provider: string]: { name: string; value: string }[]; } = {}; @@ -81,8 +83,16 @@ export const LlmTab = forwardRef( ); return (
-
- +
+ +
{llmOptions.map(({ name, value }, index) => { diff --git a/web/src/components/UserDropdown.tsx b/web/src/components/UserDropdown.tsx index 185c9862486c..cab1540bc979 100644 --- a/web/src/components/UserDropdown.tsx +++ b/web/src/components/UserDropdown.tsx @@ -9,7 +9,11 @@ import { checkUserIsNoAuthUser, logout } from "@/lib/user"; import { Popover } from "./popover/Popover"; import { LOGOUT_DISABLED } from "@/lib/constants"; import { SettingsContext } from "./settings/SettingsProvider"; -import { LightSettingsIcon } from "./icons/icons"; +import { + AssistantsIconSkeleton, + LightSettingsIcon, + UsersIcon, +} from "./icons/icons"; import { pageType } from "@/app/chat/sessionSidebar/types"; export function UserDropdown({ @@ -105,6 +109,7 @@ export function UserDropdown({ )} + {showLogout && ( <> {(!(page == "search" || page == "chat") || showAdminPanel) && ( diff --git a/web/src/lib/hooks.ts b/web/src/lib/hooks.ts index 050fd287cfa5..5bbc9c5f4625 100644 --- a/web/src/lib/hooks.ts +++ b/web/src/lib/hooks.ts @@ -137,15 +137,27 @@ export interface LlmOverride { export interface LlmOverrideManager { llmOverride: LlmOverride; setLlmOverride: React.Dispatch>; + globalDefault: LlmOverride; + setGlobalDefault: React.Dispatch>; temperature: number | null; setTemperature: React.Dispatch>; updateModelOverrideForChatSession: (chatSession?: ChatSession) => void; } - export function useLlmOverride( + globalModel?: string | null, currentChatSession?: ChatSession, defaultTemperature?: number ): LlmOverrideManager { + const [globalDefault, setGlobalDefault] = useState( + globalModel + ? destructureValue(globalModel) + : { + name: "", + provider: "", + modelName: "", + } + ); + const [llmOverride, setLlmOverride] = useState( currentChatSession && currentChatSession.current_alternate_model ? destructureValue(currentChatSession.current_alternate_model) @@ -160,11 +172,7 @@ export function useLlmOverride( setLlmOverride( chatSession && chatSession.current_alternate_model ? destructureValue(chatSession.current_alternate_model) - : { - name: "", - provider: "", - modelName: "", - } + : globalDefault ); }; @@ -180,11 +188,12 @@ export function useLlmOverride( updateModelOverrideForChatSession, llmOverride, setLlmOverride, + globalDefault, + setGlobalDefault, temperature, setTemperature, }; } - /* EE Only APIs */ diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index c717f2c57db2..6a6bdfa74a4e 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -4,6 +4,7 @@ import { Connector } from "./connectors/connectors"; export interface UserPreferences { chosen_assistants: number[] | null; + default_model: string | null; } export enum UserStatus { diff --git a/web/src/lib/users/UserSettings.tsx b/web/src/lib/users/UserSettings.tsx new file mode 100644 index 000000000000..c99a23917bf3 --- /dev/null +++ b/web/src/lib/users/UserSettings.tsx @@ -0,0 +1,15 @@ +import { LlmOverride } from "../hooks"; + +export async function setUserDefaultModel( + model: string | null +): Promise { + const response = await fetch(`/api/user/default-model`, { + method: "PATCH", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ default_model: model }), + }); + + return response; +}