Added ability to control LLM access based on group (#1870)

* Added ability to control LLM access based on group

* completed relationship deletion

* cleaned up function

* added comments

* fixed frontend strings

* mypy fixes

* added case handling for deletion of user groups

* hidden advanced options now

* removed unnecessary code
This commit is contained in:
hagen-danswer
2024-07-21 21:31:44 -07:00
committed by GitHub
parent 2f5f19642e
commit 1b49d17239
12 changed files with 558 additions and 221 deletions

View File

@@ -0,0 +1,41 @@
"""add_llm_group_permissions_control
Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558
"""
from alembic import op
import sqlalchemy as sa
revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"llm_provider__user_group",
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["llm_provider_id"],
["llm_provider.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
)
op.add_column(
"llm_provider",
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
)
def downgrade() -> None:
op.drop_table("llm_provider__user_group")
op.drop_column("llm_provider", "is_public")

View File

@@ -1,15 +1,42 @@
from sqlalchemy import delete from sqlalchemy import delete
from sqlalchemy import or_
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest from danswer.server.manage.llm.models import LLMProviderUpsertRequest
def update_group_llm_provider_relationships(
llm_provider_id: int,
group_ids: list[int] | None,
db_session: Session,
) -> None:
# Delete existing relationships
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.llm_provider_id == llm_provider_id
).delete(synchronize_session="fetch")
# Add new relationships from given group_ids
if group_ids:
new_relationships = [
LLMProvider__UserGroup(
llm_provider_id=llm_provider_id,
user_group_id=group_id,
)
for group_id in group_ids
]
db_session.add_all(new_relationships)
db_session.commit()
def upsert_cloud_embedding_provider( def upsert_cloud_embedding_provider(
db_session: Session, provider: CloudEmbeddingProviderCreationRequest db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider: ) -> CloudEmbeddingProvider:
@@ -36,36 +63,33 @@ def upsert_llm_provider(
existing_llm_provider = db_session.scalar( existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name) select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
) )
if existing_llm_provider:
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = (
llm_provider.fast_default_model_name
)
existing_llm_provider.model_names = llm_provider.model_names
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
# if it does not exist, create a new entry
llm_provider_model = LLMProviderModel(
name=llm_provider.name,
provider=llm_provider.provider,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
default_model_name=llm_provider.default_model_name,
fast_default_model_name=llm_provider.fast_default_model_name,
model_names=llm_provider.model_names,
is_default_provider=None,
)
db_session.add(llm_provider_model)
db_session.commit()
return FullLLMProvider.from_model(llm_provider_model) if not existing_llm_provider:
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
existing_llm_provider.model_names = llm_provider.model_names
existing_llm_provider.is_public = llm_provider.is_public
if not existing_llm_provider.id:
# If its not already in the db, we need to generate an ID by flushing
db_session.flush()
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships(
llm_provider_id=existing_llm_provider.id,
group_ids=llm_provider.groups,
db_session=db_session,
)
return FullLLMProvider.from_model(existing_llm_provider)
def fetch_existing_embedding_providers( def fetch_existing_embedding_providers(
@@ -74,8 +98,29 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all()) return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]: def fetch_existing_llm_providers(
return list(db_session.scalars(select(LLMProviderModel)).all()) db_session: Session,
user: User | None = None,
) -> list[LLMProviderModel]:
if not user:
return list(db_session.scalars(select(LLMProviderModel)).all())
stmt = select(LLMProviderModel).distinct()
user_groups_subquery = (
select(User__UserGroup.user_group_id)
.where(User__UserGroup.user_id == user.id)
.subquery()
)
access_conditions = or_(
LLMProviderModel.is_public,
LLMProviderModel.id.in_( # User is part of a group that has access
select(LLMProvider__UserGroup.llm_provider_id).where(
LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore
)
),
)
stmt = stmt.where(access_conditions)
return list(db_session.scalars(stmt).all())
def fetch_embedding_provider( def fetch_embedding_provider(
@@ -119,6 +164,13 @@ def remove_embedding_provider(
def remove_llm_provider(db_session: Session, provider_id: int) -> None: def remove_llm_provider(db_session: Session, provider_id: int) -> None:
# Remove LLMProvider's dependent relationships
db_session.execute(
delete(LLMProvider__UserGroup).where(
LLMProvider__UserGroup.llm_provider_id == provider_id
)
)
# Remove LLMProvider
db_session.execute( db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id) delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
) )

View File

@@ -932,6 +932,13 @@ class LLMProvider(Base):
# should only be set for a single provider # should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True) is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup",
secondary="llm_provider__user_group",
viewonly=True,
)
class CloudEmbeddingProvider(Base): class CloudEmbeddingProvider(Base):
@@ -1109,7 +1116,6 @@ class Persona(Base):
# where lower value IDs (e.g. created earlier) are displayed first # where lower value IDs (e.g. created earlier) are displayed first
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None) display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
deleted: Mapped[bool] = mapped_column(Boolean, default=False) deleted: Mapped[bool] = mapped_column(Boolean, default=False)
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
# These are only defaults, users can select from all if desired # These are only defaults, users can select from all if desired
prompts: Mapped[list[Prompt]] = relationship( prompts: Mapped[list[Prompt]] = relationship(
@@ -1137,6 +1143,7 @@ class Persona(Base):
viewonly=True, viewonly=True,
) )
# EE only # EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship( groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup", "UserGroup",
secondary="persona__user_group", secondary="persona__user_group",
@@ -1360,6 +1367,17 @@ class Persona__UserGroup(Base):
) )
class LLMProvider__UserGroup(Base):
__tablename__ = "llm_provider__user_group"
llm_provider_id: Mapped[int] = mapped_column(
ForeignKey("llm_provider.id"), primary_key=True
)
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id"), primary_key=True
)
class DocumentSet__UserGroup(Base): class DocumentSet__UserGroup(Base):
__tablename__ = "document_set__user_group" __tablename__ = "document_set__user_group"

View File

@@ -103,6 +103,7 @@ def port_api_key_to_postgres() -> None:
default_model_name=default_model_name, default_model_name=default_model_name,
fast_default_model_name=default_fast_model_name, fast_default_model_name=default_fast_model_name,
model_names=None, model_names=None,
is_public=True,
) )
llm_provider = upsert_llm_provider(db_session, llm_provider_upsert) llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
update_default_provider(db_session, llm_provider.id) update_default_provider(db_session, llm_provider.id)

View File

@@ -70,6 +70,7 @@ def load_llm_providers(db_session: Session) -> None:
FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
), ),
model_names=model_names, model_names=model_names,
is_public=True,
) )
llm_provider = upsert_llm_provider(db_session, llm_provider_request) llm_provider = upsert_llm_provider(db_session, llm_provider_request)
update_default_provider(db_session, llm_provider.id) update_default_provider(db_session, llm_provider.id)

View File

@@ -147,10 +147,10 @@ def set_provider_as_default(
@basic_router.get("/provider") @basic_router.get("/provider")
def list_llm_provider_basics( def list_llm_provider_basics(
_: User | None = Depends(current_user), user: User | None = Depends(current_user),
db_session: Session = Depends(get_session), db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]: ) -> list[LLMProviderDescriptor]:
return [ return [
LLMProviderDescriptor.from_model(llm_provider_model) LLMProviderDescriptor.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session) for llm_provider_model in fetch_existing_llm_providers(db_session, user)
] ]

View File

@@ -60,6 +60,8 @@ class LLMProvider(BaseModel):
custom_config: dict[str, str] | None custom_config: dict[str, str] | None
default_model_name: str default_model_name: str
fast_default_model_name: str | None fast_default_model_name: str | None
is_public: bool
groups: list[int] | None = None
class LLMProviderUpsertRequest(LLMProvider): class LLMProviderUpsertRequest(LLMProvider):
@@ -91,4 +93,6 @@ class FullLLMProvider(LLMProvider):
or fetch_models_for_provider(llm_provider_model.provider) or fetch_models_for_provider(llm_provider_model.provider)
or [llm_provider_model.default_model_name] or [llm_provider_model.default_model_name]
), ),
is_public=llm_provider_model.is_public,
groups=[group.id for group in llm_provider_model.groups],
) )

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from danswer.db.models import ConnectorCredentialPair from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import TokenRateLimit__UserGroup from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User from danswer.db.models import User
from danswer.db.models import User__UserGroup from danswer.db.models import User__UserGroup
@@ -194,6 +195,15 @@ def _cleanup_user__user_group_relationships__no_commit(
db_session.delete(user__user_group_relationship) db_session.delete(user__user_group_relationship)
def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _mark_user_group__cc_pair_relationships_outdated__no_commit( def _mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session: Session, user_group_id: int db_session: Session, user_group_id: int
) -> None: ) -> None:
@@ -316,6 +326,9 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
def delete_user_group(db_session: Session, user_group: UserGroup) -> None: def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
_cleanup_llm_provider__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user__user_group_relationships__no_commit( _cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id db_session=db_session, user_group_id=user_group.id
) )

View File

@@ -1,5 +1,6 @@
import { LoadingAnimation } from "@/components/Loading"; import { LoadingAnimation } from "@/components/Loading";
import { Button, Divider, Text } from "@tremor/react"; import { Button, Divider, Text } from "@tremor/react";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { import {
ArrayHelpers, ArrayHelpers,
ErrorMessage, ErrorMessage,
@@ -15,11 +16,16 @@ import {
SubLabel, SubLabel,
TextArrayField, TextArrayField,
TextFormField, TextFormField,
BooleanFormField,
} from "@/components/admin/connectors/Field"; } from "@/components/admin/connectors/Field";
import { useState } from "react"; import { useState } from "react";
import { Bubble } from "@/components/Bubble";
import { GroupsIcon } from "@/components/icons/icons";
import { useSWRConfig } from "swr"; import { useSWRConfig } from "swr";
import { useUserGroups } from "@/lib/hooks";
import { FullLLMProvider } from "./interfaces"; import { FullLLMProvider } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup"; import { PopupSpec } from "@/components/admin/connectors/Popup";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import * as Yup from "yup"; import * as Yup from "yup";
import isEqual from "lodash/isEqual"; import isEqual from "lodash/isEqual";
@@ -44,9 +50,16 @@ export function CustomLLMProviderUpdateForm({
}) { }) {
const { mutate } = useSWRConfig(); const { mutate } = useSWRConfig();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
// EE only
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const [isTesting, setIsTesting] = useState(false); const [isTesting, setIsTesting] = useState(false);
const [testError, setTestError] = useState<string>(""); const [testError, setTestError] = useState<string>("");
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
// Define the initial values based on the provider's requirements // Define the initial values based on the provider's requirements
const initialValues = { const initialValues = {
name: existingLlmProvider?.name ?? "", name: existingLlmProvider?.name ?? "",
@@ -61,6 +74,8 @@ export function CustomLLMProviderUpdateForm({
custom_config_list: existingLlmProvider?.custom_config custom_config_list: existingLlmProvider?.custom_config
? Object.entries(existingLlmProvider.custom_config) ? Object.entries(existingLlmProvider.custom_config)
: [], : [],
is_public: existingLlmProvider?.is_public ?? true,
groups: existingLlmProvider?.groups ?? [],
}; };
// Setup validation schema if required // Setup validation schema if required
@@ -74,6 +89,9 @@ export function CustomLLMProviderUpdateForm({
default_model_name: Yup.string().required("Model name is required"), default_model_name: Yup.string().required("Model name is required"),
fast_default_model_name: Yup.string().nullable(), fast_default_model_name: Yup.string().nullable(),
custom_config_list: Yup.array(), custom_config_list: Yup.array(),
// EE Only
is_public: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
}); });
return ( return (
@@ -97,6 +115,9 @@ export function CustomLLMProviderUpdateForm({
return; return;
} }
// don't set groups if marked as public
const groups = values.is_public ? [] : values.groups;
// test the configuration // test the configuration
if (!isEqual(values, initialValues)) { if (!isEqual(values, initialValues)) {
setIsTesting(true); setIsTesting(true);
@@ -188,93 +209,97 @@ export function CustomLLMProviderUpdateForm({
setSubmitting(false); setSubmitting(false);
}} }}
> >
{({ values }) => ( {({ values, setFieldValue }) => {
<Form> return (
<TextFormField <Form>
name="name" <TextFormField
label="Display Name" name="name"
subtext="A name which you can use to identify this provider when selecting it in the UI." label="Display Name"
placeholder="Display Name" subtext="A name which you can use to identify this provider when selecting it in the UI."
/> placeholder="Display Name"
/>
<Divider /> <Divider />
<TextFormField <TextFormField
name="provider" name="provider"
label="Provider Name" label="Provider Name"
subtext={ subtext={
<>
Should be one of the providers listed at{" "}
<a
target="_blank"
href="https://docs.litellm.ai/docs/providers"
className="text-link"
>
https://docs.litellm.ai/docs/providers
</a>
.
</>
}
placeholder="Name of the custom provider"
/>
<Divider />
<SubLabel>
Fill in the following as is needed. Refer to the LiteLLM
documentation for the model provider name specified above in order
to determine which fields are required.
</SubLabel>
<TextFormField
name="api_key"
label="[Optional] API Key"
placeholder="API Key"
type="password"
/>
<TextFormField
name="api_base"
label="[Optional] API Base"
placeholder="API Base"
/>
<TextFormField
name="api_version"
label="[Optional] API Version"
placeholder="API Version"
/>
<Label>[Optional] Custom Configs</Label>
<SubLabel>
<> <>
Should be one of the providers listed at{" "} <div>
<a Additional configurations needed by the model provider. Are
target="_blank" passed to litellm via environment variables.
href="https://docs.litellm.ai/docs/providers" </div>
className="text-link"
> <div className="mt-2">
https://docs.litellm.ai/docs/providers For example, when configuring the Cloudflare provider, you
</a> would need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
. Cloudflare account ID as the value.
</div>
</> </>
} </SubLabel>
placeholder="Name of the custom provider"
/>
<Divider /> <FieldArray
name="custom_config_list"
<SubLabel> render={(arrayHelpers: ArrayHelpers<any[]>) => (
Fill in the following as is needed. Refer to the LiteLLM <div>
documentation for the model provider name specified above in order {values.custom_config_list.map((_, index) => {
to determine which fields are required. return (
</SubLabel> <div
key={index}
<TextFormField className={index === 0 ? "mt-2" : "mt-6"}
name="api_key" >
label="[Optional] API Key" <div className="flex">
placeholder="API Key" <div className="w-full mr-6 border border-border p-3 rounded">
type="password" <div>
/> <Label>Key</Label>
<Field
<TextFormField name={`custom_config_list[${index}][0]`}
name="api_base" className={`
label="[Optional] API Base"
placeholder="API Base"
/>
<TextFormField
name="api_version"
label="[Optional] API Version"
placeholder="API Version"
/>
<Label>[Optional] Custom Configs</Label>
<SubLabel>
<>
<div>
Additional configurations needed by the model provider. Are
passed to litellm via environment variables.
</div>
<div className="mt-2">
For example, when configuring the Cloudflare provider, you would
need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
Cloudflare account ID as the value.
</div>
</>
</SubLabel>
<FieldArray
name="custom_config_list"
render={(arrayHelpers: ArrayHelpers<any[]>) => (
<div>
{values.custom_config_list.map((_, index) => {
return (
<div key={index} className={index === 0 ? "mt-2" : "mt-6"}>
<div className="flex">
<div className="w-full mr-6 border border-border p-3 rounded">
<div>
<Label>Key</Label>
<Field
name={`custom_config_list[${index}][0]`}
className={`
border border
border-border border-border
bg-background bg-background
@@ -284,20 +309,20 @@ export function CustomLLMProviderUpdateForm({
px-3 px-3
mr-4 mr-4
`} `}
autoComplete="off" autoComplete="off"
/> />
<ErrorMessage <ErrorMessage
name={`custom_config_list[${index}][0]`} name={`custom_config_list[${index}][0]`}
component="div" component="div"
className="text-error text-sm mt-1" className="text-error text-sm mt-1"
/> />
</div> </div>
<div className="mt-3"> <div className="mt-3">
<Label>Value</Label> <Label>Value</Label>
<Field <Field
name={`custom_config_list[${index}][1]`} name={`custom_config_list[${index}][1]`}
className={` className={`
border border
border-border border-border
bg-background bg-background
@@ -307,121 +332,190 @@ export function CustomLLMProviderUpdateForm({
px-3 px-3
mr-4 mr-4
`} `}
autoComplete="off" autoComplete="off"
/> />
<ErrorMessage <ErrorMessage
name={`custom_config_list[${index}][1]`} name={`custom_config_list[${index}][1]`}
component="div" component="div"
className="text-error text-sm mt-1" className="text-error text-sm mt-1"
/>
</div>
</div>
<div className="my-auto">
<FiX
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
onClick={() => arrayHelpers.remove(index)}
/> />
</div> </div>
</div> </div>
<div className="my-auto">
<FiX
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
onClick={() => arrayHelpers.remove(index)}
/>
</div>
</div> </div>
</div> );
); })}
})}
<Button <Button
onClick={() => { onClick={() => {
arrayHelpers.push(["", ""]); arrayHelpers.push(["", ""]);
}} }}
className="mt-3" className="mt-3"
color="green" color="green"
size="xs" size="xs"
type="button" type="button"
icon={FiPlus} icon={FiPlus}
> >
Add New Add New
</Button> </Button>
</div> </div>
)} )}
/> />
<Divider /> <Divider />
<TextArrayField <TextArrayField
name="model_names" name="model_names"
label="Model Names" label="Model Names"
values={values} values={values}
subtext={`List the individual models that you want to make subtext={`List the individual models that you want to make
available as a part of this provider. At least one must be specified. available as a part of this provider. At least one must be specified.
As an example, for OpenAI one model might be "gpt-4".`} As an example, for OpenAI one model might be "gpt-4".`}
/> />
<Divider /> <Divider />
<TextFormField <TextFormField
name="default_model_name" name="default_model_name"
subtext={` subtext={`
The model to use by default for this provider unless The model to use by default for this provider unless
otherwise specified. Must be one of the models listed otherwise specified. Must be one of the models listed
above.`} above.`}
label="Default Model" label="Default Model"
placeholder="E.g. gpt-4" placeholder="E.g. gpt-4"
/> />
<TextFormField <TextFormField
name="fast_default_model_name" name="fast_default_model_name"
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\` subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
for this provider. If not set, will use for this provider. If not set, will use
the Default Model configured above.`} the Default Model configured above.`}
label="[Optional] Fast Model" label="[Optional] Fast Model"
placeholder="E.g. gpt-4" placeholder="E.g. gpt-4"
/> />
<Divider /> <Divider />
<div> <AdvancedOptionsToggle
{/* NOTE: this is above the test button to make sure it's visible */} showAdvancedOptions={showAdvancedOptions}
{testError && <Text className="text-error mt-2">{testError}</Text>} setShowAdvancedOptions={setShowAdvancedOptions}
/>
<div className="flex w-full mt-4"> {showAdvancedOptions && (
<Button type="submit" size="xs"> <>
{isTesting ? ( {isPaidEnterpriseFeaturesEnabled && userGroups && (
<LoadingAnimation text="Testing" /> <>
) : existingLlmProvider ? ( <BooleanFormField
"Update" small
) : ( noPadding
"Enable" alignTop
name="is_public"
label="Is Public?"
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
/>
{userGroups &&
userGroups.length > 0 &&
!values.is_public && (
<div>
<Text>
Select which User Groups should have access to this
LLM Provider.
</Text>
<div className="flex flex-wrap gap-2 mt-2">
{userGroups.map((userGroup) => {
const isSelected = values.groups.includes(
userGroup.id
);
return (
<Bubble
key={userGroup.id}
isSelected={isSelected}
onClick={() => {
if (isSelected) {
setFieldValue(
"groups",
values.groups.filter(
(id) => id !== userGroup.id
)
);
} else {
setFieldValue("groups", [
...values.groups,
userGroup.id,
]);
}
}}
>
<div className="flex">
<GroupsIcon />
<div className="ml-1">{userGroup.name}</div>
</div>
</Bubble>
);
})}
</div>
</div>
)}
</>
)} )}
</Button> </>
{existingLlmProvider && ( )}
<Button
type="button"
color="red"
className="ml-3"
size="xs"
icon={FiTrash}
onClick={async () => {
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
{
method: "DELETE",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
alert(`Failed to delete provider: ${errorMsg}`);
return;
}
mutate(LLM_PROVIDERS_ADMIN_URL); <div>
onClose(); {/* NOTE: this is above the test button to make sure it's visible */}
}} {testError && (
> <Text className="text-error mt-2">{testError}</Text>
Delete
</Button>
)} )}
<div className="flex w-full mt-4">
<Button type="submit" size="xs">
{isTesting ? (
<LoadingAnimation text="Testing" />
) : existingLlmProvider ? (
"Update"
) : (
"Enable"
)}
</Button>
{existingLlmProvider && (
<Button
type="button"
color="red"
className="ml-3"
size="xs"
icon={FiTrash}
onClick={async () => {
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
{
method: "DELETE",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
alert(`Failed to delete provider: ${errorMsg}`);
return;
}
mutate(LLM_PROVIDERS_ADMIN_URL);
onClose();
}}
>
Delete
</Button>
)}
</div>
</div> </div>
</div> </Form>
</Form> );
)} }}
</Formik> </Formik>
); );
} }

View File

@@ -1,4 +1,5 @@
import { LoadingAnimation } from "@/components/Loading"; import { LoadingAnimation } from "@/components/Loading";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { Button, Divider, Text } from "@tremor/react"; import { Button, Divider, Text } from "@tremor/react";
import { Form, Formik } from "formik"; import { Form, Formik } from "formik";
import { FiTrash } from "react-icons/fi"; import { FiTrash } from "react-icons/fi";
@@ -6,11 +7,16 @@ import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import { import {
SelectorFormField, SelectorFormField,
TextFormField, TextFormField,
BooleanFormField,
} from "@/components/admin/connectors/Field"; } from "@/components/admin/connectors/Field";
import { useState } from "react"; import { useState } from "react";
import { Bubble } from "@/components/Bubble";
import { GroupsIcon } from "@/components/icons/icons";
import { useSWRConfig } from "swr"; import { useSWRConfig } from "swr";
import { useUserGroups } from "@/lib/hooks";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup"; import { PopupSpec } from "@/components/admin/connectors/Popup";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import * as Yup from "yup"; import * as Yup from "yup";
import isEqual from "lodash/isEqual"; import isEqual from "lodash/isEqual";
@@ -29,9 +35,16 @@ export function LLMProviderUpdateForm({
}) { }) {
const { mutate } = useSWRConfig(); const { mutate } = useSWRConfig();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
// EE only
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const [isTesting, setIsTesting] = useState(false); const [isTesting, setIsTesting] = useState(false);
const [testError, setTestError] = useState<string>(""); const [testError, setTestError] = useState<string>("");
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
// Define the initial values based on the provider's requirements // Define the initial values based on the provider's requirements
const initialValues = { const initialValues = {
name: existingLlmProvider?.name ?? "", name: existingLlmProvider?.name ?? "",
@@ -54,6 +67,8 @@ export function LLMProviderUpdateForm({
}, },
{} as { [key: string]: string } {} as { [key: string]: string }
), ),
is_public: existingLlmProvider?.is_public ?? true,
groups: existingLlmProvider?.groups ?? [],
}; };
const [validatedConfig, setValidatedConfig] = useState( const [validatedConfig, setValidatedConfig] = useState(
@@ -91,6 +106,9 @@ export function LLMProviderUpdateForm({
: {}), : {}),
default_model_name: Yup.string().required("Model name is required"), default_model_name: Yup.string().required("Model name is required"),
fast_default_model_name: Yup.string().nullable(), fast_default_model_name: Yup.string().nullable(),
// EE Only
is_public: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
}); });
return ( return (
@@ -193,7 +211,7 @@ export function LLMProviderUpdateForm({
setSubmitting(false); setSubmitting(false);
}} }}
> >
{({ values }) => ( {({ values, setFieldValue }) => (
<Form> <Form>
<TextFormField <TextFormField
name="name" name="name"
@@ -293,6 +311,69 @@ export function LLMProviderUpdateForm({
<Divider /> <Divider />
<AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions}
setShowAdvancedOptions={setShowAdvancedOptions}
/>
{showAdvancedOptions && (
<>
{isPaidEnterpriseFeaturesEnabled && userGroups && (
<>
<BooleanFormField
small
noPadding
alignTop
name="is_public"
label="Is Public?"
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
/>
{userGroups && userGroups.length > 0 && !values.is_public && (
<div>
<Text>
Select which User Groups should have access to this LLM
Provider.
</Text>
<div className="flex flex-wrap gap-2 mt-2">
{userGroups.map((userGroup) => {
const isSelected = values.groups.includes(
userGroup.id
);
return (
<Bubble
key={userGroup.id}
isSelected={isSelected}
onClick={() => {
if (isSelected) {
setFieldValue(
"groups",
values.groups.filter(
(id) => id !== userGroup.id
)
);
} else {
setFieldValue("groups", [
...values.groups,
userGroup.id,
]);
}
}}
>
<div className="flex">
<GroupsIcon />
<div className="ml-1">{userGroup.name}</div>
</div>
</Bubble>
);
})}
</div>
</div>
)}
</>
)}
</>
)}
<div> <div>
{/* NOTE: this is above the test button to make sure it's visible */} {/* NOTE: this is above the test button to make sure it's visible */}
{testError && <Text className="text-error mt-2">{testError}</Text>} {testError && <Text className="text-error mt-2">{testError}</Text>}

View File

@@ -17,6 +17,8 @@ export interface WellKnownLLMProviderDescriptor {
llm_names: string[]; llm_names: string[];
default_model: string | null; default_model: string | null;
default_fast_model: string | null; default_fast_model: string | null;
is_public: boolean;
groups: number[];
} }
export interface LLMProvider { export interface LLMProvider {
@@ -28,6 +30,8 @@ export interface LLMProvider {
custom_config: { [key: string]: string } | null; custom_config: { [key: string]: string } | null;
default_model_name: string; default_model_name: string;
fast_default_model_name: string | null; fast_default_model_name: string | null;
is_public: boolean;
groups: number[];
} }
export interface FullLLMProvider extends LLMProvider { export interface FullLLMProvider extends LLMProvider {
@@ -44,4 +48,6 @@ export interface LLMProviderDescriptor {
default_model_name: string; default_model_name: string;
fast_default_model_name: string | null; fast_default_model_name: string | null;
is_default_provider: boolean | null; is_default_provider: boolean | null;
is_public: boolean;
groups: number[];
} }

View File

@@ -0,0 +1,26 @@
import React from "react";
import { Button } from "@tremor/react";
import { FiChevronDown, FiChevronRight } from "react-icons/fi";
interface AdvancedOptionsToggleProps {
showAdvancedOptions: boolean;
setShowAdvancedOptions: (show: boolean) => void;
}
export function AdvancedOptionsToggle({
showAdvancedOptions,
setShowAdvancedOptions,
}: AdvancedOptionsToggleProps) {
return (
<Button
type="button"
variant="light"
size="xs"
icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
className="mb-4 text-xs text-text-500 hover:text-text-400"
>
Advanced Options
</Button>
);
}