Add display names to LLMProvider + allow multiple configs from the same provider

This commit is contained in:
Weves 2024-05-07 15:33:41 -07:00 committed by Chris Weaver
parent d6522426c9
commit 76a5f26fe1
12 changed files with 299 additions and 105 deletions

View File

@ -0,0 +1,45 @@
"""Add user-configured names to LLMProvider
Revision ID: 643a84a42a33
Revises: 0a98909f2757
Create Date: 2024-05-07 14:54:55.493100
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "643a84a42a33"
down_revision = "0a98909f2757"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("llm_provider", sa.Column("provider", sa.String(), nullable=True))
# move "name" -> "provider" to match the new schema
op.execute("UPDATE llm_provider SET provider = name")
# pretty up display name
op.execute("UPDATE llm_provider SET name = 'OpenAI' WHERE name = 'openai'")
op.execute("UPDATE llm_provider SET name = 'Anthropic' WHERE name = 'anthropic'")
op.execute("UPDATE llm_provider SET name = 'Azure OpenAI' WHERE name = 'azure'")
op.execute("UPDATE llm_provider SET name = 'AWS Bedrock' WHERE name = 'bedrock'")
# update personas to use the new provider names
op.execute(
"UPDATE persona SET llm_model_provider_override = 'OpenAI' WHERE llm_model_provider_override = 'openai'"
)
op.execute(
"UPDATE persona SET llm_model_provider_override = 'Anthropic' WHERE llm_model_provider_override = 'anthropic'"
)
op.execute(
"UPDATE persona SET llm_model_provider_override = 'Azure OpenAI' WHERE llm_model_provider_override = 'azure'"
)
op.execute(
"UPDATE persona SET llm_model_provider_override = 'AWS Bedrock' WHERE llm_model_provider_override = 'bedrock'"
)
def downgrade() -> None:
op.execute("UPDATE llm_provider SET name = provider")
op.drop_column("llm_provider", "provider")

View File

@ -14,6 +14,7 @@ def upsert_llm_provider(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider:
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
@ -29,6 +30,7 @@ def upsert_llm_provider(
# if it does not exist, create a new entry
llm_provider_model = LLMProviderModel(
name=llm_provider.name,
provider=llm_provider.provider,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,

View File

@ -779,6 +779,7 @@ class LLMProvider(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
provider: Mapped[str] = mapped_column(String)
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
api_version: Mapped[str | None] = mapped_column(String, nullable=True)

View File

@ -94,6 +94,7 @@ def port_api_key_to_postgres() -> None:
llm_provider_upsert = LLMProviderUpsertRequest(
name=GEN_AI_MODEL_PROVIDER,
provider=GEN_AI_MODEL_PROVIDER,
api_key=api_key,
api_base=GEN_AI_API_ENDPOINT,
api_version=GEN_AI_API_VERSION,

View File

@ -20,11 +20,10 @@ def get_llm_for_persona(
temperature_override = llm_override.temperature if llm_override else None
return get_default_llm(
gen_ai_model_provider=model_provider_override
or persona.llm_model_provider_override,
gen_ai_model_version_override=(
model_version_override or persona.llm_model_version_override
model_provider_name=(
model_provider_override or persona.llm_model_provider_override
),
model_version=(model_version_override or persona.llm_model_version_override),
temperature=temperature_override or GEN_AI_TEMPERATURE,
)
@ -33,23 +32,23 @@ def get_default_llm(
timeout: int = QA_TIMEOUT,
temperature: float = GEN_AI_TEMPERATURE,
use_fast_llm: bool = False,
gen_ai_model_provider: str | None = None,
gen_ai_model_version_override: str | None = None,
model_provider_name: str | None = None,
model_version: str | None = None,
) -> LLM:
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()
# TODO: pass this in
with get_session_context_manager() as session:
if gen_ai_model_provider is None:
if model_provider_name is None:
llm_provider = fetch_default_provider(session)
else:
llm_provider = fetch_provider(session, gen_ai_model_provider)
llm_provider = fetch_provider(session, model_provider_name)
if not llm_provider:
raise ValueError("No default LLM provider found")
model_name = gen_ai_model_version_override or (
model_name = model_version or (
(llm_provider.fast_default_model_name or llm_provider.default_model_name)
if use_fast_llm
else llm_provider.default_model_name
@ -58,7 +57,7 @@ def get_default_llm(
raise ValueError("No default model name found")
return get_llm(
provider=llm_provider.name,
provider=llm_provider.provider,
model=model_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,

View File

@ -26,6 +26,7 @@ class LLMProviderDescriptor(BaseModel):
non-admin users. Used when giving a list of available LLMs."""
name: str
provider: str
model_names: list[str]
default_model_name: str
fast_default_model_name: str | None
@ -37,12 +38,13 @@ class LLMProviderDescriptor(BaseModel):
) -> "LLMProviderDescriptor":
return cls(
name=llm_provider_model.name,
provider=llm_provider_model.provider,
default_model_name=llm_provider_model.default_model_name,
fast_default_model_name=llm_provider_model.fast_default_model_name,
is_default_provider=llm_provider_model.is_default_provider,
model_names=(
llm_provider_model.model_names
or fetch_models_for_provider(llm_provider_model.name)
or fetch_models_for_provider(llm_provider_model.provider)
or [llm_provider_model.default_model_name]
),
)
@ -50,6 +52,7 @@ class LLMProviderDescriptor(BaseModel):
class LLMProvider(BaseModel):
name: str
provider: str
api_key: str | None
api_base: str | None
api_version: str | None
@ -74,6 +77,7 @@ class FullLLMProvider(LLMProvider):
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
provider=llm_provider_model.provider,
api_key=llm_provider_model.api_key,
api_base=llm_provider_model.api_base,
api_version=llm_provider_model.api_version,

View File

@ -34,13 +34,6 @@ import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelec
import { FullLLMProvider } from "../models/llm/interfaces";
import { Option } from "@/components/Dropdown";
const DEFAULT_LLM_PROVIDER_TO_DISPLAY_NAME: Record<string, string> = {
openai: "OpenAI",
azure: "Azure OpenAI",
anthropic: "Anthropic",
bedrock: "AWS Bedrock",
};
function Label({ children }: { children: string | JSX.Element }) {
return (
<div className="block font-medium text-base text-emphasis">{children}</div>
@ -495,10 +488,7 @@ export function AssistantEditor({
<SelectorFormField
name="llm_model_provider_override"
options={llmProviders.map((llmProvider) => ({
name:
DEFAULT_LLM_PROVIDER_TO_DISPLAY_NAME[
llmProvider.name
] || llmProvider.name,
name: llmProvider.name,
value: llmProvider.name,
}))}
includeDefault={true}

View File

@ -0,0 +1,170 @@
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { Modal } from "@/components/Modal";
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm";
import { useState } from "react";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import { mutate } from "swr";
import { Badge, Button } from "@tremor/react";
function LLMProviderUpdateModal({
llmProviderDescriptor,
onClose,
existingLlmProvider,
shouldMarkAsDefault,
setPopup,
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null;
onClose: () => void;
existingLlmProvider?: FullLLMProvider;
shouldMarkAsDefault?: boolean;
setPopup?: (popup: PopupSpec) => void;
}) {
const providerName =
llmProviderDescriptor?.display_name ||
llmProviderDescriptor?.name ||
existingLlmProvider?.name ||
"Custom LLM Provider";
return (
<Modal title={`Setup ${providerName}`} onOutsideClick={() => onClose()}>
<div className="max-h-[70vh] overflow-y-auto px-4">
{llmProviderDescriptor ? (
<LLMProviderUpdateForm
llmProviderDescriptor={llmProviderDescriptor}
onClose={onClose}
existingLlmProvider={existingLlmProvider}
shouldMarkAsDefault={shouldMarkAsDefault}
setPopup={setPopup}
/>
) : (
<CustomLLMProviderUpdateForm
onClose={onClose}
existingLlmProvider={existingLlmProvider}
shouldMarkAsDefault={shouldMarkAsDefault}
setPopup={setPopup}
/>
)}
</div>
</Modal>
);
}
function LLMProviderDisplay({
llmProviderDescriptor,
existingLlmProvider,
shouldMarkAsDefault,
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null;
existingLlmProvider: FullLLMProvider;
shouldMarkAsDefault?: boolean;
}) {
const [formIsVisible, setFormIsVisible] = useState(false);
const { popup, setPopup } = usePopup();
const providerName =
llmProviderDescriptor?.display_name ||
llmProviderDescriptor?.name ||
existingLlmProvider?.name;
return (
<div>
{popup}
<div className="border border-border p-3 rounded w-96 flex shadow-md">
<div className="my-auto">
<div className="font-bold">{providerName} </div>
<div className="text-xs italic">({existingLlmProvider.provider})</div>
{!existingLlmProvider.is_default_provider && (
<div
className="text-xs text-link cursor-pointer pt-1"
onClick={async () => {
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`,
{
method: "POST",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
setPopup({
type: "error",
message: `Failed to set provider as default: ${errorMsg}`,
});
return;
}
mutate(LLM_PROVIDERS_ADMIN_URL);
setPopup({
type: "success",
message: "Provider set as default successfully!",
});
}}
>
Set as default
</div>
)}
</div>
{existingLlmProvider && (
<div className="my-auto ml-3">
{existingLlmProvider.is_default_provider ? (
<Badge color="orange" size="xs">
Default
</Badge>
) : (
<Badge color="green" size="xs">
Enabled
</Badge>
)}
</div>
)}
<div className="ml-auto">
<Button
color={existingLlmProvider ? "green" : "blue"}
size="xs"
onClick={() => setFormIsVisible(true)}
>
{existingLlmProvider ? "Edit" : "Set up"}
</Button>
</div>
</div>
{formIsVisible && (
<LLMProviderUpdateModal
llmProviderDescriptor={llmProviderDescriptor}
onClose={() => setFormIsVisible(false)}
existingLlmProvider={existingLlmProvider}
shouldMarkAsDefault={shouldMarkAsDefault}
setPopup={setPopup}
/>
)}
</div>
);
}
export function ConfiguredLLMProviderDisplay({
existingLlmProviders,
}: {
existingLlmProviders: FullLLMProvider[];
}) {
existingLlmProviders = existingLlmProviders.sort((a, b) => {
if (a.is_default_provider && !b.is_default_provider) {
return -1;
}
if (!a.is_default_provider && b.is_default_provider) {
return 1;
}
return a.provider > b.provider ? 1 : -1;
});
return (
<div className="gap-y-4 flex flex-col">
{existingLlmProviders.map((provider) => (
<LLMProviderDisplay
key={provider.id}
llmProviderDescriptor={null}
existingLlmProvider={provider}
/>
))}
</div>
);
}

View File

@ -53,6 +53,7 @@ export function CustomLLMProviderUpdateForm({
// Define the initial values based on the provider's requirements
const initialValues = {
name: existingLlmProvider?.name ?? "",
provider: existingLlmProvider?.provider ?? "",
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? "",
api_version: existingLlmProvider?.api_version ?? "",
@ -71,7 +72,8 @@ export function CustomLLMProviderUpdateForm({
// Setup validation schema if required
const validationSchema = Yup.object({
name: Yup.string().required("Name is required"),
name: Yup.string().required("Display Name is required"),
provider: Yup.string().required("Provider Name is required"),
api_key: Yup.string(),
api_base: Yup.string(),
api_version: Yup.string(),
@ -185,6 +187,15 @@ export function CustomLLMProviderUpdateForm({
<Form>
<TextFormField
name="name"
label="Display Name"
subtext="A name which you can use to identify this provider when selecting it in the UI."
placeholder="Display Name"
/>
<Divider />
<TextFormField
name="provider"
label="Provider Name"
subtext={
<>
@ -384,7 +395,6 @@ export function CustomLLMProviderUpdateForm({
disabled={isTesting}
onClick={async () => {
setIsTesting(true);
console.log(values.custom_config_list);
const response = await fetch("/api/admin/llm/test", {
method: "POST",
@ -392,7 +402,6 @@ export function CustomLLMProviderUpdateForm({
"Content-Type": "application/json",
},
body: JSON.stringify({
provider: values.name,
custom_config: customConfigProcessing(
values.custom_config_list
),

View File

@ -3,14 +3,15 @@
import { Modal } from "@/components/Modal";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { useState } from "react";
import useSWR, { mutate } from "swr";
import { Badge, Button, Text, Title } from "@tremor/react";
import useSWR from "swr";
import { Button, Callout, Text, Title } from "@tremor/react";
import { ThreeDotsLoader } from "@/components/Loading";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm";
import { ConfiguredLLMProviderDisplay } from "./ConfiguredLLMProviderDisplay";
function LLMProviderUpdateModal({
llmProviderDescriptor,
@ -54,80 +55,29 @@ function LLMProviderUpdateModal({
);
}
function LLMProviderDisplay({
function DefaultLLMProviderDisplay({
llmProviderDescriptor,
existingLlmProvider,
shouldMarkAsDefault,
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null;
existingLlmProvider?: FullLLMProvider;
shouldMarkAsDefault?: boolean;
}) {
const [formIsVisible, setFormIsVisible] = useState(false);
const { popup, setPopup } = usePopup();
const providerName =
llmProviderDescriptor?.display_name ||
llmProviderDescriptor?.name ||
existingLlmProvider?.name;
llmProviderDescriptor?.display_name || llmProviderDescriptor?.name;
return (
<div>
{popup}
<div className="border border-border p-3 rounded w-96 flex shadow-md">
<div className="my-auto">
<div className="font-bold">{providerName} </div>
{existingLlmProvider && !existingLlmProvider.is_default_provider && (
<div
className="text-xs text-link cursor-pointer"
onClick={async () => {
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`,
{
method: "POST",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
setPopup({
type: "error",
message: `Failed to set provider as default: ${errorMsg}`,
});
return;
}
mutate(LLM_PROVIDERS_ADMIN_URL);
setPopup({
type: "success",
message: "Provider set as default successfully!",
});
}}
>
Set as default
</div>
)}
</div>
{existingLlmProvider && (
<div className="my-auto">
{existingLlmProvider.is_default_provider ? (
<Badge color="orange" className="ml-2" size="xs">
Default
</Badge>
) : (
<Badge color="green" className="ml-2" size="xs">
Enabled
</Badge>
)}
</div>
)}
<div className="ml-auto">
<Button
color={existingLlmProvider ? "green" : "blue"}
size="xs"
onClick={() => setFormIsVisible(true)}
>
{existingLlmProvider ? "Edit" : "Set up"}
<Button color="blue" size="xs" onClick={() => setFormIsVisible(true)}>
Set up
</Button>
</div>
</div>
@ -135,7 +85,6 @@ function LLMProviderDisplay({
<LLMProviderUpdateModal
llmProviderDescriptor={llmProviderDescriptor}
onClose={() => setFormIsVisible(false)}
existingLlmProvider={existingLlmProvider}
shouldMarkAsDefault={shouldMarkAsDefault}
setPopup={setPopup}
/>
@ -144,7 +93,11 @@ function LLMProviderDisplay({
);
}
function AddCustomLLMProvider({}) {
function AddCustomLLMProvider({
existingLlmProviders,
}: {
existingLlmProviders: FullLLMProvider[];
}) {
const [formIsVisible, setFormIsVisible] = useState(false);
if (formIsVisible) {
@ -156,6 +109,7 @@ function AddCustomLLMProvider({}) {
<div className="max-h-[70vh] overflow-y-auto px-4">
<CustomLLMProviderUpdateForm
onClose={() => setFormIsVisible(false)}
shouldMarkAsDefault={existingLlmProviders.length === 0}
/>
</div>
</Modal>
@ -191,13 +145,32 @@ export function LLMConfiguration() {
return (
<>
<Title className="mb-2">Enabled LLM Providers</Title>
{existingLlmProviders.length > 0 ? (
<>
<Text className="mb-4">
If multiple LLM providers are enabled, the default provider will be
used for all &quot;Default&quot; Assistants. For user-created
Assistants, you can select the LLM provider/model that best fits the
use case!
</Text>
<ConfiguredLLMProviderDisplay
existingLlmProviders={existingLlmProviders}
/>
</>
) : (
<Callout title="No LLM providers configured yet" color="yellow">
Please set one up below in order to start using Danswer!
</Callout>
)}
<Title className="mb-2 mt-6">Add LLM Provider</Title>
<Text className="mb-4">
If multiple LLM providers are enabled, the default provider will be used
for all &quot;Default&quot; Personas. For user-created Personas, you can
select the LLM provider/model that best fits the use case!
Add a new LLM provider by either selecting from one of the default
providers or by specifying your own custom LLM provider.
</Text>
<Title className="mb-2">Default Providers</Title>
<div className="gap-y-4 flex flex-col">
{llmProviderDescriptors.map((llmProviderDescriptor) => {
const existingLlmProvider = existingLlmProviders.find(
@ -205,30 +178,18 @@ export function LLMConfiguration() {
);
return (
<LLMProviderDisplay
<DefaultLLMProviderDisplay
key={llmProviderDescriptor.name}
llmProviderDescriptor={llmProviderDescriptor}
existingLlmProvider={existingLlmProvider}
shouldMarkAsDefault={existingLlmProviders.length === 0}
/>
);
})}
</div>
<Title className="mb-2 mt-4">Custom Providers</Title>
{customLLMProviders.length > 0 && (
<div className="gap-y-4 flex flex-col mb-4">
{customLLMProviders.map((llmProvider) => (
<LLMProviderDisplay
key={llmProvider.id}
llmProviderDescriptor={null}
existingLlmProvider={llmProvider}
/>
))}
</div>
)}
<AddCustomLLMProvider />
<div className="mt-4">
<AddCustomLLMProvider existingLlmProviders={existingLlmProviders} />
</div>
</>
);
}

View File

@ -37,6 +37,7 @@ export function LLMProviderUpdateForm({
// Define the initial values based on the provider's requirements
const initialValues = {
name: existingLlmProvider?.name ?? "",
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? "",
api_version: existingLlmProvider?.api_version ?? "",
@ -64,6 +65,7 @@ export function LLMProviderUpdateForm({
// Setup validation schema if required
const validationSchema = Yup.object({
name: Yup.string().required("Display Name is required"),
api_key: llmProviderDescriptor.api_key_required
? Yup.string().required("API Key is required")
: Yup.string(),
@ -118,7 +120,7 @@ export function LLMProviderUpdateForm({
"Content-Type": "application/json",
},
body: JSON.stringify({
name: llmProviderDescriptor.name,
provider: llmProviderDescriptor.name,
...values,
fast_default_model_name:
values.default_fast_model_name || values.default_model_name,
@ -184,6 +186,15 @@ export function LLMProviderUpdateForm({
>
{({ values }) => (
<Form>
<TextFormField
name="name"
label="Display Name"
subtext="A name which you can use to identify this provider when selecting it in the UI."
placeholder="Display Name"
/>
<Divider />
{llmProviderDescriptor.api_key_required && (
<TextFormField
name="api_key"

View File

@ -21,6 +21,7 @@ export interface WellKnownLLMProviderDescriptor {
export interface LLMProvider {
name: string;
provider: string;
api_key: string | null;
api_base: string | null;
api_version: string | null;