From 76a5f26fe127229ecaa67cc28aae11877fdb1eb7 Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 7 May 2024 15:33:41 -0700 Subject: [PATCH] Add display names to LLMProvider + allow multiple configs from the same provider --- ...dd_user_configured_names_to_llmprovider.py | 45 +++++ backend/danswer/db/llm.py | 2 + backend/danswer/db/models.py | 1 + .../danswer/dynamic_configs/port_configs.py | 1 + backend/danswer/llm/factory.py | 19 +- backend/danswer/server/manage/llm/models.py | 6 +- .../app/admin/assistants/AssistantEditor.tsx | 12 +- .../llm/ConfiguredLLMProviderDisplay.tsx | 170 ++++++++++++++++++ .../llm/CustomLLMProviderUpdateForm.tsx | 15 +- .../app/admin/models/llm/LLMConfiguration.tsx | 119 +++++------- .../models/llm/LLMProviderUpdateForm.tsx | 13 +- web/src/app/admin/models/llm/interfaces.ts | 1 + 12 files changed, 299 insertions(+), 105 deletions(-) create mode 100644 backend/alembic/versions/643a84a42a33_add_user_configured_names_to_llmprovider.py create mode 100644 web/src/app/admin/models/llm/ConfiguredLLMProviderDisplay.tsx diff --git a/backend/alembic/versions/643a84a42a33_add_user_configured_names_to_llmprovider.py b/backend/alembic/versions/643a84a42a33_add_user_configured_names_to_llmprovider.py new file mode 100644 index 000000000..1a35ff6de --- /dev/null +++ b/backend/alembic/versions/643a84a42a33_add_user_configured_names_to_llmprovider.py @@ -0,0 +1,45 @@ +"""Add user-configured names to LLMProvider + +Revision ID: 643a84a42a33 +Revises: 0a98909f2757 +Create Date: 2024-05-07 14:54:55.493100 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "643a84a42a33" +down_revision = "0a98909f2757" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column("llm_provider", sa.Column("provider", sa.String(), nullable=True)) + # move "name" -> "provider" to match the new schema + op.execute("UPDATE llm_provider SET provider = name") + # pretty up display name + op.execute("UPDATE llm_provider SET name = 'OpenAI' WHERE name = 'openai'") + op.execute("UPDATE llm_provider SET name = 'Anthropic' WHERE name = 'anthropic'") + op.execute("UPDATE llm_provider SET name = 'Azure OpenAI' WHERE name = 'azure'") + op.execute("UPDATE llm_provider SET name = 'AWS Bedrock' WHERE name = 'bedrock'") + + # update personas to use the new provider names + op.execute( + "UPDATE persona SET llm_model_provider_override = 'OpenAI' WHERE llm_model_provider_override = 'openai'" + ) + op.execute( + "UPDATE persona SET llm_model_provider_override = 'Anthropic' WHERE llm_model_provider_override = 'anthropic'" + ) + op.execute( + "UPDATE persona SET llm_model_provider_override = 'Azure OpenAI' WHERE llm_model_provider_override = 'azure'" + ) + op.execute( + "UPDATE persona SET llm_model_provider_override = 'AWS Bedrock' WHERE llm_model_provider_override = 'bedrock'" + ) + + +def downgrade() -> None: + op.execute("UPDATE llm_provider SET name = provider") + op.drop_column("llm_provider", "provider") diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py index a502180db..f969dbf68 100644 --- a/backend/danswer/db/llm.py +++ b/backend/danswer/db/llm.py @@ -14,6 +14,7 @@ def upsert_llm_provider( select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name) ) if existing_llm_provider: + existing_llm_provider.provider = llm_provider.provider existing_llm_provider.api_key = llm_provider.api_key existing_llm_provider.api_base = llm_provider.api_base existing_llm_provider.api_version = llm_provider.api_version @@ -29,6 +30,7 @@ def upsert_llm_provider( # if it does not exist, create a new entry llm_provider_model = LLMProviderModel( name=llm_provider.name, + provider=llm_provider.provider, api_key=llm_provider.api_key, api_base=llm_provider.api_base, api_version=llm_provider.api_version, diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index c62aa2a0b..1a2b615a5 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -779,6 +779,7 @@ class LLMProvider(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str] = mapped_column(String, unique=True) + provider: Mapped[str] = mapped_column(String) api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True) api_base: Mapped[str | None] = mapped_column(String, nullable=True) api_version: Mapped[str | None] = mapped_column(String, nullable=True) diff --git a/backend/danswer/dynamic_configs/port_configs.py b/backend/danswer/dynamic_configs/port_configs.py index d0c55e698..b28615a62 100644 --- a/backend/danswer/dynamic_configs/port_configs.py +++ b/backend/danswer/dynamic_configs/port_configs.py @@ -94,6 +94,7 @@ def port_api_key_to_postgres() -> None: llm_provider_upsert = LLMProviderUpsertRequest( name=GEN_AI_MODEL_PROVIDER, + provider=GEN_AI_MODEL_PROVIDER, api_key=api_key, api_base=GEN_AI_API_ENDPOINT, api_version=GEN_AI_API_VERSION, diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py index 3f4172c51..9c92eb9a6 100644 --- a/backend/danswer/llm/factory.py +++ b/backend/danswer/llm/factory.py @@ -20,11 +20,10 @@ def get_llm_for_persona( temperature_override = llm_override.temperature if llm_override else None return get_default_llm( - gen_ai_model_provider=model_provider_override - or persona.llm_model_provider_override, - gen_ai_model_version_override=( - model_version_override or persona.llm_model_version_override + model_provider_name=( + model_provider_override or persona.llm_model_provider_override ), + model_version=(model_version_override or persona.llm_model_version_override), temperature=temperature_override or GEN_AI_TEMPERATURE, ) @@ -33,23 +32,23 @@ def get_default_llm( timeout: int = QA_TIMEOUT, temperature: float = GEN_AI_TEMPERATURE, use_fast_llm: bool = False, - gen_ai_model_provider: str | None = None, - gen_ai_model_version_override: str | None = None, + model_provider_name: str | None = None, + model_version: str | None = None, ) -> LLM: if DISABLE_GENERATIVE_AI: raise GenAIDisabledException() # TODO: pass this in with get_session_context_manager() as session: - if gen_ai_model_provider is None: + if model_provider_name is None: llm_provider = fetch_default_provider(session) else: - llm_provider = fetch_provider(session, gen_ai_model_provider) + llm_provider = fetch_provider(session, model_provider_name) if not llm_provider: raise ValueError("No default LLM provider found") - model_name = gen_ai_model_version_override or ( + model_name = model_version or ( (llm_provider.fast_default_model_name or llm_provider.default_model_name) if use_fast_llm else llm_provider.default_model_name @@ -58,7 +57,7 @@ def get_default_llm( raise ValueError("No default model name found") return get_llm( - provider=llm_provider.name, + provider=llm_provider.provider, model=model_name, api_key=llm_provider.api_key, api_base=llm_provider.api_base, diff --git a/backend/danswer/server/manage/llm/models.py b/backend/danswer/server/manage/llm/models.py index 628bb0a7a..54ec9d8b1 100644 --- a/backend/danswer/server/manage/llm/models.py +++ b/backend/danswer/server/manage/llm/models.py @@ -26,6 +26,7 @@ class LLMProviderDescriptor(BaseModel): non-admin users. Used when giving a list of available LLMs.""" name: str + provider: str model_names: list[str] default_model_name: str fast_default_model_name: str | None @@ -37,12 +38,13 @@ class LLMProviderDescriptor(BaseModel): ) -> "LLMProviderDescriptor": return cls( name=llm_provider_model.name, + provider=llm_provider_model.provider, default_model_name=llm_provider_model.default_model_name, fast_default_model_name=llm_provider_model.fast_default_model_name, is_default_provider=llm_provider_model.is_default_provider, model_names=( llm_provider_model.model_names - or fetch_models_for_provider(llm_provider_model.name) + or fetch_models_for_provider(llm_provider_model.provider) or [llm_provider_model.default_model_name] ), ) @@ -50,6 +52,7 @@ class LLMProviderDescriptor(BaseModel): class LLMProvider(BaseModel): name: str + provider: str api_key: str | None api_base: str | None api_version: str | None @@ -74,6 +77,7 @@ class FullLLMProvider(LLMProvider): return cls( id=llm_provider_model.id, name=llm_provider_model.name, + provider=llm_provider_model.provider, api_key=llm_provider_model.api_key, api_base=llm_provider_model.api_base, api_version=llm_provider_model.api_version, diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index bcded7031..90f4a3281 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -34,13 +34,6 @@ import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelec import { FullLLMProvider } from "../models/llm/interfaces"; import { Option } from "@/components/Dropdown"; -const DEFAULT_LLM_PROVIDER_TO_DISPLAY_NAME: Record = { - openai: "OpenAI", - azure: "Azure OpenAI", - anthropic: "Anthropic", - bedrock: "AWS Bedrock", -}; - function Label({ children }: { children: string | JSX.Element }) { return (
{children}
@@ -495,10 +488,7 @@ export function AssistantEditor({ ({ - name: - DEFAULT_LLM_PROVIDER_TO_DISPLAY_NAME[ - llmProvider.name - ] || llmProvider.name, + name: llmProvider.name, value: llmProvider.name, }))} includeDefault={true} diff --git a/web/src/app/admin/models/llm/ConfiguredLLMProviderDisplay.tsx b/web/src/app/admin/models/llm/ConfiguredLLMProviderDisplay.tsx new file mode 100644 index 000000000..ec58db7f3 --- /dev/null +++ b/web/src/app/admin/models/llm/ConfiguredLLMProviderDisplay.tsx @@ -0,0 +1,170 @@ +import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; +import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; +import { Modal } from "@/components/Modal"; +import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm"; +import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm"; +import { useState } from "react"; +import { LLM_PROVIDERS_ADMIN_URL } from "./constants"; +import { mutate } from "swr"; +import { Badge, Button } from "@tremor/react"; + +function LLMProviderUpdateModal({ + llmProviderDescriptor, + onClose, + existingLlmProvider, + shouldMarkAsDefault, + setPopup, +}: { + llmProviderDescriptor: WellKnownLLMProviderDescriptor | null; + onClose: () => void; + existingLlmProvider?: FullLLMProvider; + shouldMarkAsDefault?: boolean; + setPopup?: (popup: PopupSpec) => void; +}) { + const providerName = + llmProviderDescriptor?.display_name || + llmProviderDescriptor?.name || + existingLlmProvider?.name || + "Custom LLM Provider"; + return ( + onClose()}> +
+ {llmProviderDescriptor ? ( + + ) : ( + + )} +
+
+ ); +} + +function LLMProviderDisplay({ + llmProviderDescriptor, + existingLlmProvider, + shouldMarkAsDefault, +}: { + llmProviderDescriptor: WellKnownLLMProviderDescriptor | null; + existingLlmProvider: FullLLMProvider; + shouldMarkAsDefault?: boolean; +}) { + const [formIsVisible, setFormIsVisible] = useState(false); + const { popup, setPopup } = usePopup(); + + const providerName = + llmProviderDescriptor?.display_name || + llmProviderDescriptor?.name || + existingLlmProvider?.name; + return ( +
+ {popup} +
+
+
{providerName}
+
({existingLlmProvider.provider})
+ {!existingLlmProvider.is_default_provider && ( +
{ + const response = await fetch( + `${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`, + { + method: "POST", + } + ); + if (!response.ok) { + const errorMsg = (await response.json()).detail; + setPopup({ + type: "error", + message: `Failed to set provider as default: ${errorMsg}`, + }); + return; + } + + mutate(LLM_PROVIDERS_ADMIN_URL); + setPopup({ + type: "success", + message: "Provider set as default successfully!", + }); + }} + > + Set as default +
+ )} +
+ + {existingLlmProvider && ( +
+ {existingLlmProvider.is_default_provider ? ( + + Default + + ) : ( + + Enabled + + )} +
+ )} + +
+ +
+
+ {formIsVisible && ( + setFormIsVisible(false)} + existingLlmProvider={existingLlmProvider} + shouldMarkAsDefault={shouldMarkAsDefault} + setPopup={setPopup} + /> + )} +
+ ); +} + +export function ConfiguredLLMProviderDisplay({ + existingLlmProviders, +}: { + existingLlmProviders: FullLLMProvider[]; +}) { + existingLlmProviders = existingLlmProviders.sort((a, b) => { + if (a.is_default_provider && !b.is_default_provider) { + return -1; + } + if (!a.is_default_provider && b.is_default_provider) { + return 1; + } + return a.provider > b.provider ? 1 : -1; + }); + + return ( +
+ {existingLlmProviders.map((provider) => ( + + ))} +
+ ); +} diff --git a/web/src/app/admin/models/llm/CustomLLMProviderUpdateForm.tsx b/web/src/app/admin/models/llm/CustomLLMProviderUpdateForm.tsx index b81730ac1..7ada70203 100644 --- a/web/src/app/admin/models/llm/CustomLLMProviderUpdateForm.tsx +++ b/web/src/app/admin/models/llm/CustomLLMProviderUpdateForm.tsx @@ -53,6 +53,7 @@ export function CustomLLMProviderUpdateForm({ // Define the initial values based on the provider's requirements const initialValues = { name: existingLlmProvider?.name ?? "", + provider: existingLlmProvider?.provider ?? "", api_key: existingLlmProvider?.api_key ?? "", api_base: existingLlmProvider?.api_base ?? "", api_version: existingLlmProvider?.api_version ?? "", @@ -71,7 +72,8 @@ export function CustomLLMProviderUpdateForm({ // Setup validation schema if required const validationSchema = Yup.object({ - name: Yup.string().required("Name is required"), + name: Yup.string().required("Display Name is required"), + provider: Yup.string().required("Provider Name is required"), api_key: Yup.string(), api_base: Yup.string(), api_version: Yup.string(), @@ -185,6 +187,15 @@ export function CustomLLMProviderUpdateForm({
+ + + + @@ -384,7 +395,6 @@ export function CustomLLMProviderUpdateForm({ disabled={isTesting} onClick={async () => { setIsTesting(true); - console.log(values.custom_config_list); const response = await fetch("/api/admin/llm/test", { method: "POST", @@ -392,7 +402,6 @@ export function CustomLLMProviderUpdateForm({ "Content-Type": "application/json", }, body: JSON.stringify({ - provider: values.name, custom_config: customConfigProcessing( values.custom_config_list ), diff --git a/web/src/app/admin/models/llm/LLMConfiguration.tsx b/web/src/app/admin/models/llm/LLMConfiguration.tsx index 744579677..47a17262e 100644 --- a/web/src/app/admin/models/llm/LLMConfiguration.tsx +++ b/web/src/app/admin/models/llm/LLMConfiguration.tsx @@ -3,14 +3,15 @@ import { Modal } from "@/components/Modal"; import { errorHandlingFetcher } from "@/lib/fetcher"; import { useState } from "react"; -import useSWR, { mutate } from "swr"; -import { Badge, Button, Text, Title } from "@tremor/react"; +import useSWR from "swr"; +import { Button, Callout, Text, Title } from "@tremor/react"; import { ThreeDotsLoader } from "@/components/Loading"; import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm"; import { LLM_PROVIDERS_ADMIN_URL } from "./constants"; import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm"; +import { ConfiguredLLMProviderDisplay } from "./ConfiguredLLMProviderDisplay"; function LLMProviderUpdateModal({ llmProviderDescriptor, @@ -54,80 +55,29 @@ function LLMProviderUpdateModal({ ); } -function LLMProviderDisplay({ +function DefaultLLMProviderDisplay({ llmProviderDescriptor, - existingLlmProvider, shouldMarkAsDefault, }: { llmProviderDescriptor: WellKnownLLMProviderDescriptor | null; - existingLlmProvider?: FullLLMProvider; shouldMarkAsDefault?: boolean; }) { const [formIsVisible, setFormIsVisible] = useState(false); const { popup, setPopup } = usePopup(); const providerName = - llmProviderDescriptor?.display_name || - llmProviderDescriptor?.name || - existingLlmProvider?.name; + llmProviderDescriptor?.display_name || llmProviderDescriptor?.name; return (
{popup}
{providerName}
- {existingLlmProvider && !existingLlmProvider.is_default_provider && ( -
{ - const response = await fetch( - `${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`, - { - method: "POST", - } - ); - if (!response.ok) { - const errorMsg = (await response.json()).detail; - setPopup({ - type: "error", - message: `Failed to set provider as default: ${errorMsg}`, - }); - return; - } - - mutate(LLM_PROVIDERS_ADMIN_URL); - setPopup({ - type: "success", - message: "Provider set as default successfully!", - }); - }} - > - Set as default -
- )}
- {existingLlmProvider && ( -
- {existingLlmProvider.is_default_provider ? ( - - Default - - ) : ( - - Enabled - - )} -
- )} -
-
@@ -135,7 +85,6 @@ function LLMProviderDisplay({ setFormIsVisible(false)} - existingLlmProvider={existingLlmProvider} shouldMarkAsDefault={shouldMarkAsDefault} setPopup={setPopup} /> @@ -144,7 +93,11 @@ function LLMProviderDisplay({ ); } -function AddCustomLLMProvider({}) { +function AddCustomLLMProvider({ + existingLlmProviders, +}: { + existingLlmProviders: FullLLMProvider[]; +}) { const [formIsVisible, setFormIsVisible] = useState(false); if (formIsVisible) { @@ -156,6 +109,7 @@ function AddCustomLLMProvider({}) {
setFormIsVisible(false)} + shouldMarkAsDefault={existingLlmProviders.length === 0} />
@@ -191,13 +145,32 @@ export function LLMConfiguration() { return ( <> + Enabled LLM Providers + + {existingLlmProviders.length > 0 ? ( + <> + + If multiple LLM providers are enabled, the default provider will be + used for all "Default" Assistants. For user-created + Assistants, you can select the LLM provider/model that best fits the + use case! + + + + ) : ( + + Please set one up below in order to start using Danswer! + + )} + + Add LLM Provider - If multiple LLM providers are enabled, the default provider will be used - for all "Default" Personas. For user-created Personas, you can - select the LLM provider/model that best fits the use case! + Add a new LLM provider by either selecting from one of the default + providers or by specifying your own custom LLM provider. - Default Providers
{llmProviderDescriptors.map((llmProviderDescriptor) => { const existingLlmProvider = existingLlmProviders.find( @@ -205,30 +178,18 @@ export function LLMConfiguration() { ); return ( - ); })}
- Custom Providers - {customLLMProviders.length > 0 && ( -
- {customLLMProviders.map((llmProvider) => ( - - ))} -
- )} - - +
+ +
); } diff --git a/web/src/app/admin/models/llm/LLMProviderUpdateForm.tsx b/web/src/app/admin/models/llm/LLMProviderUpdateForm.tsx index 6dfa863b4..c3340f5e7 100644 --- a/web/src/app/admin/models/llm/LLMProviderUpdateForm.tsx +++ b/web/src/app/admin/models/llm/LLMProviderUpdateForm.tsx @@ -37,6 +37,7 @@ export function LLMProviderUpdateForm({ // Define the initial values based on the provider's requirements const initialValues = { + name: existingLlmProvider?.name ?? "", api_key: existingLlmProvider?.api_key ?? "", api_base: existingLlmProvider?.api_base ?? "", api_version: existingLlmProvider?.api_version ?? "", @@ -64,6 +65,7 @@ export function LLMProviderUpdateForm({ // Setup validation schema if required const validationSchema = Yup.object({ + name: Yup.string().required("Display Name is required"), api_key: llmProviderDescriptor.api_key_required ? Yup.string().required("API Key is required") : Yup.string(), @@ -118,7 +120,7 @@ export function LLMProviderUpdateForm({ "Content-Type": "application/json", }, body: JSON.stringify({ - name: llmProviderDescriptor.name, + provider: llmProviderDescriptor.name, ...values, fast_default_model_name: values.default_fast_model_name || values.default_model_name, @@ -184,6 +186,15 @@ export function LLMProviderUpdateForm({ > {({ values }) => ( + + + + {llmProviderDescriptor.api_key_required && (