Added ability to control LLM access based on group (#1870)

* Added ability to control LLM access based on group

* completed relationship deletion

* cleaned up function

* added comments

* fixed frontend strings

* mypy fixes

* added case handling for deletion of user groups

* hidden advanced options now

* removed unnecessary code
This commit is contained in:
hagen-danswer 2024-07-21 21:31:44 -07:00 committed by GitHub
parent 2f5f19642e
commit 1b49d17239
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 558 additions and 221 deletions

View File

@ -0,0 +1,41 @@
"""add_llm_group_permissions_control
Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558
"""
from alembic import op
import sqlalchemy as sa
revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"llm_provider__user_group",
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["llm_provider_id"],
["llm_provider.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
)
op.add_column(
"llm_provider",
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
)
def downgrade() -> None:
op.drop_table("llm_provider__user_group")
op.drop_column("llm_provider", "is_public")

View File

@ -1,15 +1,42 @@
from sqlalchemy import delete
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
def update_group_llm_provider_relationships(
llm_provider_id: int,
group_ids: list[int] | None,
db_session: Session,
) -> None:
# Delete existing relationships
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.llm_provider_id == llm_provider_id
).delete(synchronize_session="fetch")
# Add new relationships from given group_ids
if group_ids:
new_relationships = [
LLMProvider__UserGroup(
llm_provider_id=llm_provider_id,
user_group_id=group_id,
)
for group_id in group_ids
]
db_session.add_all(new_relationships)
db_session.commit()
def upsert_cloud_embedding_provider(
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
@ -36,36 +63,33 @@ def upsert_llm_provider(
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider:
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = (
llm_provider.fast_default_model_name
)
existing_llm_provider.model_names = llm_provider.model_names
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
# if it does not exist, create a new entry
llm_provider_model = LLMProviderModel(
name=llm_provider.name,
provider=llm_provider.provider,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
default_model_name=llm_provider.default_model_name,
fast_default_model_name=llm_provider.fast_default_model_name,
model_names=llm_provider.model_names,
is_default_provider=None,
)
db_session.add(llm_provider_model)
db_session.commit()
return FullLLMProvider.from_model(llm_provider_model)
if not existing_llm_provider:
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
existing_llm_provider.model_names = llm_provider.model_names
existing_llm_provider.is_public = llm_provider.is_public
if not existing_llm_provider.id:
# If its not already in the db, we need to generate an ID by flushing
db_session.flush()
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships(
llm_provider_id=existing_llm_provider.id,
group_ids=llm_provider.groups,
db_session=db_session,
)
return FullLLMProvider.from_model(existing_llm_provider)
def fetch_existing_embedding_providers(
@ -74,8 +98,29 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
return list(db_session.scalars(select(LLMProviderModel)).all())
def fetch_existing_llm_providers(
db_session: Session,
user: User | None = None,
) -> list[LLMProviderModel]:
if not user:
return list(db_session.scalars(select(LLMProviderModel)).all())
stmt = select(LLMProviderModel).distinct()
user_groups_subquery = (
select(User__UserGroup.user_group_id)
.where(User__UserGroup.user_id == user.id)
.subquery()
)
access_conditions = or_(
LLMProviderModel.is_public,
LLMProviderModel.id.in_( # User is part of a group that has access
select(LLMProvider__UserGroup.llm_provider_id).where(
LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore
)
),
)
stmt = stmt.where(access_conditions)
return list(db_session.scalars(stmt).all())
def fetch_embedding_provider(
@ -119,6 +164,13 @@ def remove_embedding_provider(
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
# Remove LLMProvider's dependent relationships
db_session.execute(
delete(LLMProvider__UserGroup).where(
LLMProvider__UserGroup.llm_provider_id == provider_id
)
)
# Remove LLMProvider
db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)

View File

@ -932,6 +932,13 @@ class LLMProvider(Base):
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup",
secondary="llm_provider__user_group",
viewonly=True,
)
class CloudEmbeddingProvider(Base):
@ -1109,7 +1116,6 @@ class Persona(Base):
# where lower value IDs (e.g. created earlier) are displayed first
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
# These are only defaults, users can select from all if desired
prompts: Mapped[list[Prompt]] = relationship(
@ -1137,6 +1143,7 @@ class Persona(Base):
viewonly=True,
)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup",
secondary="persona__user_group",
@ -1360,6 +1367,17 @@ class Persona__UserGroup(Base):
)
class LLMProvider__UserGroup(Base):
__tablename__ = "llm_provider__user_group"
llm_provider_id: Mapped[int] = mapped_column(
ForeignKey("llm_provider.id"), primary_key=True
)
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id"), primary_key=True
)
class DocumentSet__UserGroup(Base):
__tablename__ = "document_set__user_group"

View File

@ -103,6 +103,7 @@ def port_api_key_to_postgres() -> None:
default_model_name=default_model_name,
fast_default_model_name=default_fast_model_name,
model_names=None,
is_public=True,
)
llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
update_default_provider(db_session, llm_provider.id)

View File

@ -70,6 +70,7 @@ def load_llm_providers(db_session: Session) -> None:
FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
),
model_names=model_names,
is_public=True,
)
llm_provider = upsert_llm_provider(db_session, llm_provider_request)
update_default_provider(db_session, llm_provider.id)

View File

@ -147,10 +147,10 @@ def set_provider_as_default(
@basic_router.get("/provider")
def list_llm_provider_basics(
_: User | None = Depends(current_user),
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
return [
LLMProviderDescriptor.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
for llm_provider_model in fetch_existing_llm_providers(db_session, user)
]

View File

@ -60,6 +60,8 @@ class LLMProvider(BaseModel):
custom_config: dict[str, str] | None
default_model_name: str
fast_default_model_name: str | None
is_public: bool
groups: list[int] | None = None
class LLMProviderUpsertRequest(LLMProvider):
@ -91,4 +93,6 @@ class FullLLMProvider(LLMProvider):
or fetch_models_for_provider(llm_provider_model.provider)
or [llm_provider_model.default_model_name]
),
is_public=llm_provider_model.is_public,
groups=[group.id for group in llm_provider_model.groups],
)

View File

@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
@ -194,6 +195,15 @@ def _cleanup_user__user_group_relationships__no_commit(
db_session.delete(user__user_group_relationship)
def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session: Session, user_group_id: int
) -> None:
@ -316,6 +326,9 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
_cleanup_llm_provider__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)

View File

@ -1,5 +1,6 @@
import { LoadingAnimation } from "@/components/Loading";
import { Button, Divider, Text } from "@tremor/react";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import {
ArrayHelpers,
ErrorMessage,
@ -15,11 +16,16 @@ import {
SubLabel,
TextArrayField,
TextFormField,
BooleanFormField,
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { Bubble } from "@/components/Bubble";
import { GroupsIcon } from "@/components/icons/icons";
import { useSWRConfig } from "swr";
import { useUserGroups } from "@/lib/hooks";
import { FullLLMProvider } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
@ -44,9 +50,16 @@ export function CustomLLMProviderUpdateForm({
}) {
const { mutate } = useSWRConfig();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
// EE only
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const [isTesting, setIsTesting] = useState(false);
const [testError, setTestError] = useState<string>("");
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
// Define the initial values based on the provider's requirements
const initialValues = {
name: existingLlmProvider?.name ?? "",
@ -61,6 +74,8 @@ export function CustomLLMProviderUpdateForm({
custom_config_list: existingLlmProvider?.custom_config
? Object.entries(existingLlmProvider.custom_config)
: [],
is_public: existingLlmProvider?.is_public ?? true,
groups: existingLlmProvider?.groups ?? [],
};
// Setup validation schema if required
@ -74,6 +89,9 @@ export function CustomLLMProviderUpdateForm({
default_model_name: Yup.string().required("Model name is required"),
fast_default_model_name: Yup.string().nullable(),
custom_config_list: Yup.array(),
// EE Only
is_public: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
});
return (
@ -97,6 +115,9 @@ export function CustomLLMProviderUpdateForm({
return;
}
// don't set groups if marked as public
const groups = values.is_public ? [] : values.groups;
// test the configuration
if (!isEqual(values, initialValues)) {
setIsTesting(true);
@ -188,93 +209,97 @@ export function CustomLLMProviderUpdateForm({
setSubmitting(false);
}}
>
{({ values }) => (
<Form>
<TextFormField
name="name"
label="Display Name"
subtext="A name which you can use to identify this provider when selecting it in the UI."
placeholder="Display Name"
/>
{({ values, setFieldValue }) => {
return (
<Form>
<TextFormField
name="name"
label="Display Name"
subtext="A name which you can use to identify this provider when selecting it in the UI."
placeholder="Display Name"
/>
<Divider />
<Divider />
<TextFormField
name="provider"
label="Provider Name"
subtext={
<TextFormField
name="provider"
label="Provider Name"
subtext={
<>
Should be one of the providers listed at{" "}
<a
target="_blank"
href="https://docs.litellm.ai/docs/providers"
className="text-link"
>
https://docs.litellm.ai/docs/providers
</a>
.
</>
}
placeholder="Name of the custom provider"
/>
<Divider />
<SubLabel>
Fill in the following as is needed. Refer to the LiteLLM
documentation for the model provider name specified above in order
to determine which fields are required.
</SubLabel>
<TextFormField
name="api_key"
label="[Optional] API Key"
placeholder="API Key"
type="password"
/>
<TextFormField
name="api_base"
label="[Optional] API Base"
placeholder="API Base"
/>
<TextFormField
name="api_version"
label="[Optional] API Version"
placeholder="API Version"
/>
<Label>[Optional] Custom Configs</Label>
<SubLabel>
<>
Should be one of the providers listed at{" "}
<a
target="_blank"
href="https://docs.litellm.ai/docs/providers"
className="text-link"
>
https://docs.litellm.ai/docs/providers
</a>
.
<div>
Additional configurations needed by the model provider. Are
passed to litellm via environment variables.
</div>
<div className="mt-2">
For example, when configuring the Cloudflare provider, you
would need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
Cloudflare account ID as the value.
</div>
</>
}
placeholder="Name of the custom provider"
/>
</SubLabel>
<Divider />
<SubLabel>
Fill in the following as is needed. Refer to the LiteLLM
documentation for the model provider name specified above in order
to determine which fields are required.
</SubLabel>
<TextFormField
name="api_key"
label="[Optional] API Key"
placeholder="API Key"
type="password"
/>
<TextFormField
name="api_base"
label="[Optional] API Base"
placeholder="API Base"
/>
<TextFormField
name="api_version"
label="[Optional] API Version"
placeholder="API Version"
/>
<Label>[Optional] Custom Configs</Label>
<SubLabel>
<>
<div>
Additional configurations needed by the model provider. Are
passed to litellm via environment variables.
</div>
<div className="mt-2">
For example, when configuring the Cloudflare provider, you would
need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
Cloudflare account ID as the value.
</div>
</>
</SubLabel>
<FieldArray
name="custom_config_list"
render={(arrayHelpers: ArrayHelpers<any[]>) => (
<div>
{values.custom_config_list.map((_, index) => {
return (
<div key={index} className={index === 0 ? "mt-2" : "mt-6"}>
<div className="flex">
<div className="w-full mr-6 border border-border p-3 rounded">
<div>
<Label>Key</Label>
<Field
name={`custom_config_list[${index}][0]`}
className={`
<FieldArray
name="custom_config_list"
render={(arrayHelpers: ArrayHelpers<any[]>) => (
<div>
{values.custom_config_list.map((_, index) => {
return (
<div
key={index}
className={index === 0 ? "mt-2" : "mt-6"}
>
<div className="flex">
<div className="w-full mr-6 border border-border p-3 rounded">
<div>
<Label>Key</Label>
<Field
name={`custom_config_list[${index}][0]`}
className={`
border
border-border
bg-background
@ -284,20 +309,20 @@ export function CustomLLMProviderUpdateForm({
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][0]`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][0]`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
<div className="mt-3">
<Label>Value</Label>
<Field
name={`custom_config_list[${index}][1]`}
className={`
<div className="mt-3">
<Label>Value</Label>
<Field
name={`custom_config_list[${index}][1]`}
className={`
border
border-border
bg-background
@ -307,121 +332,190 @@ export function CustomLLMProviderUpdateForm({
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][1]`}
component="div"
className="text-error text-sm mt-1"
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][1]`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
</div>
<div className="my-auto">
<FiX
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
onClick={() => arrayHelpers.remove(index)}
/>
</div>
</div>
<div className="my-auto">
<FiX
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
onClick={() => arrayHelpers.remove(index)}
/>
</div>
</div>
</div>
);
})}
);
})}
<Button
onClick={() => {
arrayHelpers.push(["", ""]);
}}
className="mt-3"
color="green"
size="xs"
type="button"
icon={FiPlus}
>
Add New
</Button>
</div>
)}
/>
<Button
onClick={() => {
arrayHelpers.push(["", ""]);
}}
className="mt-3"
color="green"
size="xs"
type="button"
icon={FiPlus}
>
Add New
</Button>
</div>
)}
/>
<Divider />
<Divider />
<TextArrayField
name="model_names"
label="Model Names"
values={values}
subtext={`List the individual models that you want to make
<TextArrayField
name="model_names"
label="Model Names"
values={values}
subtext={`List the individual models that you want to make
available as a part of this provider. At least one must be specified.
As an example, for OpenAI one model might be "gpt-4".`}
/>
/>
<Divider />
<Divider />
<TextFormField
name="default_model_name"
subtext={`
<TextFormField
name="default_model_name"
subtext={`
The model to use by default for this provider unless
otherwise specified. Must be one of the models listed
above.`}
label="Default Model"
placeholder="E.g. gpt-4"
/>
label="Default Model"
placeholder="E.g. gpt-4"
/>
<TextFormField
name="fast_default_model_name"
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
<TextFormField
name="fast_default_model_name"
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
for this provider. If not set, will use
the Default Model configured above.`}
label="[Optional] Fast Model"
placeholder="E.g. gpt-4"
/>
label="[Optional] Fast Model"
placeholder="E.g. gpt-4"
/>
<Divider />
<Divider />
<div>
{/* NOTE: this is above the test button to make sure it's visible */}
{testError && <Text className="text-error mt-2">{testError}</Text>}
<AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions}
setShowAdvancedOptions={setShowAdvancedOptions}
/>
<div className="flex w-full mt-4">
<Button type="submit" size="xs">
{isTesting ? (
<LoadingAnimation text="Testing" />
) : existingLlmProvider ? (
"Update"
) : (
"Enable"
{showAdvancedOptions && (
<>
{isPaidEnterpriseFeaturesEnabled && userGroups && (
<>
<BooleanFormField
small
noPadding
alignTop
name="is_public"
label="Is Public?"
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
/>
{userGroups &&
userGroups.length > 0 &&
!values.is_public && (
<div>
<Text>
Select which User Groups should have access to this
LLM Provider.
</Text>
<div className="flex flex-wrap gap-2 mt-2">
{userGroups.map((userGroup) => {
const isSelected = values.groups.includes(
userGroup.id
);
return (
<Bubble
key={userGroup.id}
isSelected={isSelected}
onClick={() => {
if (isSelected) {
setFieldValue(
"groups",
values.groups.filter(
(id) => id !== userGroup.id
)
);
} else {
setFieldValue("groups", [
...values.groups,
userGroup.id,
]);
}
}}
>
<div className="flex">
<GroupsIcon />
<div className="ml-1">{userGroup.name}</div>
</div>
</Bubble>
);
})}
</div>
</div>
)}
</>
)}
</Button>
{existingLlmProvider && (
<Button
type="button"
color="red"
className="ml-3"
size="xs"
icon={FiTrash}
onClick={async () => {
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
{
method: "DELETE",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
alert(`Failed to delete provider: ${errorMsg}`);
return;
}
</>
)}
mutate(LLM_PROVIDERS_ADMIN_URL);
onClose();
}}
>
Delete
</Button>
<div>
{/* NOTE: this is above the test button to make sure it's visible */}
{testError && (
<Text className="text-error mt-2">{testError}</Text>
)}
<div className="flex w-full mt-4">
<Button type="submit" size="xs">
{isTesting ? (
<LoadingAnimation text="Testing" />
) : existingLlmProvider ? (
"Update"
) : (
"Enable"
)}
</Button>
{existingLlmProvider && (
<Button
type="button"
color="red"
className="ml-3"
size="xs"
icon={FiTrash}
onClick={async () => {
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
{
method: "DELETE",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
alert(`Failed to delete provider: ${errorMsg}`);
return;
}
mutate(LLM_PROVIDERS_ADMIN_URL);
onClose();
}}
>
Delete
</Button>
)}
</div>
</div>
</div>
</Form>
)}
</Form>
);
}}
</Formik>
);
}

View File

@ -1,4 +1,5 @@
import { LoadingAnimation } from "@/components/Loading";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { Button, Divider, Text } from "@tremor/react";
import { Form, Formik } from "formik";
import { FiTrash } from "react-icons/fi";
@ -6,11 +7,16 @@ import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import {
SelectorFormField,
TextFormField,
BooleanFormField,
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { Bubble } from "@/components/Bubble";
import { GroupsIcon } from "@/components/icons/icons";
import { useSWRConfig } from "swr";
import { useUserGroups } from "@/lib/hooks";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
@ -29,9 +35,16 @@ export function LLMProviderUpdateForm({
}) {
const { mutate } = useSWRConfig();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
// EE only
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const [isTesting, setIsTesting] = useState(false);
const [testError, setTestError] = useState<string>("");
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
// Define the initial values based on the provider's requirements
const initialValues = {
name: existingLlmProvider?.name ?? "",
@ -54,6 +67,8 @@ export function LLMProviderUpdateForm({
},
{} as { [key: string]: string }
),
is_public: existingLlmProvider?.is_public ?? true,
groups: existingLlmProvider?.groups ?? [],
};
const [validatedConfig, setValidatedConfig] = useState(
@ -91,6 +106,9 @@ export function LLMProviderUpdateForm({
: {}),
default_model_name: Yup.string().required("Model name is required"),
fast_default_model_name: Yup.string().nullable(),
// EE Only
is_public: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
});
return (
@ -193,7 +211,7 @@ export function LLMProviderUpdateForm({
setSubmitting(false);
}}
>
{({ values }) => (
{({ values, setFieldValue }) => (
<Form>
<TextFormField
name="name"
@ -293,6 +311,69 @@ export function LLMProviderUpdateForm({
<Divider />
<AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions}
setShowAdvancedOptions={setShowAdvancedOptions}
/>
{showAdvancedOptions && (
<>
{isPaidEnterpriseFeaturesEnabled && userGroups && (
<>
<BooleanFormField
small
noPadding
alignTop
name="is_public"
label="Is Public?"
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
/>
{userGroups && userGroups.length > 0 && !values.is_public && (
<div>
<Text>
Select which User Groups should have access to this LLM
Provider.
</Text>
<div className="flex flex-wrap gap-2 mt-2">
{userGroups.map((userGroup) => {
const isSelected = values.groups.includes(
userGroup.id
);
return (
<Bubble
key={userGroup.id}
isSelected={isSelected}
onClick={() => {
if (isSelected) {
setFieldValue(
"groups",
values.groups.filter(
(id) => id !== userGroup.id
)
);
} else {
setFieldValue("groups", [
...values.groups,
userGroup.id,
]);
}
}}
>
<div className="flex">
<GroupsIcon />
<div className="ml-1">{userGroup.name}</div>
</div>
</Bubble>
);
})}
</div>
</div>
)}
</>
)}
</>
)}
<div>
{/* NOTE: this is above the test button to make sure it's visible */}
{testError && <Text className="text-error mt-2">{testError}</Text>}

View File

@ -17,6 +17,8 @@ export interface WellKnownLLMProviderDescriptor {
llm_names: string[];
default_model: string | null;
default_fast_model: string | null;
is_public: boolean;
groups: number[];
}
export interface LLMProvider {
@ -28,6 +30,8 @@ export interface LLMProvider {
custom_config: { [key: string]: string } | null;
default_model_name: string;
fast_default_model_name: string | null;
is_public: boolean;
groups: number[];
}
export interface FullLLMProvider extends LLMProvider {
@ -44,4 +48,6 @@ export interface LLMProviderDescriptor {
default_model_name: string;
fast_default_model_name: string | null;
is_default_provider: boolean | null;
is_public: boolean;
groups: number[];
}

View File

@ -0,0 +1,26 @@
import React from "react";
import { Button } from "@tremor/react";
import { FiChevronDown, FiChevronRight } from "react-icons/fi";
interface AdvancedOptionsToggleProps {
showAdvancedOptions: boolean;
setShowAdvancedOptions: (show: boolean) => void;
}
export function AdvancedOptionsToggle({
showAdvancedOptions,
setShowAdvancedOptions,
}: AdvancedOptionsToggleProps) {
return (
<Button
type="button"
variant="light"
size="xs"
icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
className="mb-4 text-xs text-text-500 hover:text-text-400"
>
Advanced Options
</Button>
);
}