mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-09 12:30:49 +02:00
Added ability to control LLM access based on group (#1870)
* Added ability to control LLM access based on group * completed relationship deletion * cleaned up function * added comments * fixed frontend strings * mypy fixes * added case handling for deletion of user groups * hidden advanced options now * removed unnecessary code
This commit is contained in:
parent
2f5f19642e
commit
1b49d17239
@ -0,0 +1,41 @@
|
||||
"""add_llm_group_permissions_control
|
||||
|
||||
Revision ID: 795b20b85b4b
|
||||
Revises: 05c07bf07c00
|
||||
Create Date: 2024-07-19 11:54:35.701558
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = "795b20b85b4b"
|
||||
down_revision = "05c07bf07c00"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"llm_provider__user_group",
|
||||
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_group_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["llm_provider_id"],
|
||||
["llm_provider.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_group_id"],
|
||||
["user_group.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
|
||||
)
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("llm_provider__user_group")
|
||||
op.drop_column("llm_provider", "is_public")
|
@ -1,15 +1,42 @@
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
from danswer.db.models import LLMProvider__UserGroup
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.llm.models import FullLLMProvider
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
|
||||
|
||||
def update_group_llm_provider_relationships(
|
||||
llm_provider_id: int,
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
# Delete existing relationships
|
||||
db_session.query(LLMProvider__UserGroup).filter(
|
||||
LLMProvider__UserGroup.llm_provider_id == llm_provider_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
# Add new relationships from given group_ids
|
||||
if group_ids:
|
||||
new_relationships = [
|
||||
LLMProvider__UserGroup(
|
||||
llm_provider_id=llm_provider_id,
|
||||
user_group_id=group_id,
|
||||
)
|
||||
for group_id in group_ids
|
||||
]
|
||||
db_session.add_all(new_relationships)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def upsert_cloud_embedding_provider(
|
||||
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
|
||||
) -> CloudEmbeddingProvider:
|
||||
@ -36,36 +63,33 @@ def upsert_llm_provider(
|
||||
existing_llm_provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
||||
)
|
||||
if existing_llm_provider:
|
||||
existing_llm_provider.provider = llm_provider.provider
|
||||
existing_llm_provider.api_key = llm_provider.api_key
|
||||
existing_llm_provider.api_base = llm_provider.api_base
|
||||
existing_llm_provider.api_version = llm_provider.api_version
|
||||
existing_llm_provider.custom_config = llm_provider.custom_config
|
||||
existing_llm_provider.default_model_name = llm_provider.default_model_name
|
||||
existing_llm_provider.fast_default_model_name = (
|
||||
llm_provider.fast_default_model_name
|
||||
)
|
||||
existing_llm_provider.model_names = llm_provider.model_names
|
||||
db_session.commit()
|
||||
return FullLLMProvider.from_model(existing_llm_provider)
|
||||
# if it does not exist, create a new entry
|
||||
llm_provider_model = LLMProviderModel(
|
||||
name=llm_provider.name,
|
||||
provider=llm_provider.provider,
|
||||
api_key=llm_provider.api_key,
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
custom_config=llm_provider.custom_config,
|
||||
default_model_name=llm_provider.default_model_name,
|
||||
fast_default_model_name=llm_provider.fast_default_model_name,
|
||||
model_names=llm_provider.model_names,
|
||||
is_default_provider=None,
|
||||
)
|
||||
db_session.add(llm_provider_model)
|
||||
db_session.commit()
|
||||
|
||||
return FullLLMProvider.from_model(llm_provider_model)
|
||||
if not existing_llm_provider:
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
existing_llm_provider.provider = llm_provider.provider
|
||||
existing_llm_provider.api_key = llm_provider.api_key
|
||||
existing_llm_provider.api_base = llm_provider.api_base
|
||||
existing_llm_provider.api_version = llm_provider.api_version
|
||||
existing_llm_provider.custom_config = llm_provider.custom_config
|
||||
existing_llm_provider.default_model_name = llm_provider.default_model_name
|
||||
existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
|
||||
existing_llm_provider.model_names = llm_provider.model_names
|
||||
existing_llm_provider.is_public = llm_provider.is_public
|
||||
|
||||
if not existing_llm_provider.id:
|
||||
# If its not already in the db, we need to generate an ID by flushing
|
||||
db_session.flush()
|
||||
|
||||
# Make sure the relationship table stays up to date
|
||||
update_group_llm_provider_relationships(
|
||||
llm_provider_id=existing_llm_provider.id,
|
||||
group_ids=llm_provider.groups,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return FullLLMProvider.from_model(existing_llm_provider)
|
||||
|
||||
|
||||
def fetch_existing_embedding_providers(
|
||||
@ -74,8 +98,29 @@ def fetch_existing_embedding_providers(
|
||||
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
||||
|
||||
|
||||
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
|
||||
return list(db_session.scalars(select(LLMProviderModel)).all())
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
) -> list[LLMProviderModel]:
|
||||
if not user:
|
||||
return list(db_session.scalars(select(LLMProviderModel)).all())
|
||||
stmt = select(LLMProviderModel).distinct()
|
||||
user_groups_subquery = (
|
||||
select(User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id == user.id)
|
||||
.subquery()
|
||||
)
|
||||
access_conditions = or_(
|
||||
LLMProviderModel.is_public,
|
||||
LLMProviderModel.id.in_( # User is part of a group that has access
|
||||
select(LLMProvider__UserGroup.llm_provider_id).where(
|
||||
LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore
|
||||
)
|
||||
),
|
||||
)
|
||||
stmt = stmt.where(access_conditions)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
@ -119,6 +164,13 @@ def remove_embedding_provider(
|
||||
|
||||
|
||||
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||
# Remove LLMProvider's dependent relationships
|
||||
db_session.execute(
|
||||
delete(LLMProvider__UserGroup).where(
|
||||
LLMProvider__UserGroup.llm_provider_id == provider_id
|
||||
)
|
||||
)
|
||||
# Remove LLMProvider
|
||||
db_session.execute(
|
||||
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||
)
|
||||
|
@ -932,6 +932,13 @@ class LLMProvider(Base):
|
||||
|
||||
# should only be set for a single provider
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
groups: Mapped[list["UserGroup"]] = relationship(
|
||||
"UserGroup",
|
||||
secondary="llm_provider__user_group",
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(Base):
|
||||
@ -1109,7 +1116,6 @@ class Persona(Base):
|
||||
# where lower value IDs (e.g. created earlier) are displayed first
|
||||
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
# These are only defaults, users can select from all if desired
|
||||
prompts: Mapped[list[Prompt]] = relationship(
|
||||
@ -1137,6 +1143,7 @@ class Persona(Base):
|
||||
viewonly=True,
|
||||
)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
groups: Mapped[list["UserGroup"]] = relationship(
|
||||
"UserGroup",
|
||||
secondary="persona__user_group",
|
||||
@ -1360,6 +1367,17 @@ class Persona__UserGroup(Base):
|
||||
)
|
||||
|
||||
|
||||
class LLMProvider__UserGroup(Base):
|
||||
__tablename__ = "llm_provider__user_group"
|
||||
|
||||
llm_provider_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("llm_provider.id"), primary_key=True
|
||||
)
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class DocumentSet__UserGroup(Base):
|
||||
__tablename__ = "document_set__user_group"
|
||||
|
||||
|
@ -103,6 +103,7 @@ def port_api_key_to_postgres() -> None:
|
||||
default_model_name=default_model_name,
|
||||
fast_default_model_name=default_fast_model_name,
|
||||
model_names=None,
|
||||
is_public=True,
|
||||
)
|
||||
llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
|
||||
update_default_provider(db_session, llm_provider.id)
|
||||
|
@ -70,6 +70,7 @@ def load_llm_providers(db_session: Session) -> None:
|
||||
FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
|
||||
),
|
||||
model_names=model_names,
|
||||
is_public=True,
|
||||
)
|
||||
llm_provider = upsert_llm_provider(db_session, llm_provider_request)
|
||||
update_default_provider(db_session, llm_provider.id)
|
||||
|
@ -147,10 +147,10 @@ def set_provider_as_default(
|
||||
|
||||
@basic_router.get("/provider")
|
||||
def list_llm_provider_basics(
|
||||
_: User | None = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
return [
|
||||
LLMProviderDescriptor.from_model(llm_provider_model)
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session)
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session, user)
|
||||
]
|
||||
|
@ -60,6 +60,8 @@ class LLMProvider(BaseModel):
|
||||
custom_config: dict[str, str] | None
|
||||
default_model_name: str
|
||||
fast_default_model_name: str | None
|
||||
is_public: bool
|
||||
groups: list[int] | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
@ -91,4 +93,6 @@ class FullLLMProvider(LLMProvider):
|
||||
or fetch_models_for_provider(llm_provider_model.provider)
|
||||
or [llm_provider_model.default_model_name]
|
||||
),
|
||||
is_public=llm_provider_model.is_public,
|
||||
groups=[group.id for group in llm_provider_model.groups],
|
||||
)
|
||||
|
@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Document
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import LLMProvider__UserGroup
|
||||
from danswer.db.models import TokenRateLimit__UserGroup
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
@ -194,6 +195,15 @@ def _cleanup_user__user_group_relationships__no_commit(
|
||||
db_session.delete(user__user_group_relationship)
|
||||
|
||||
|
||||
def _cleanup_llm_provider__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
db_session.query(LLMProvider__UserGroup).filter(
|
||||
LLMProvider__UserGroup.user_group_id == user_group_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
@ -316,6 +326,9 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
|
||||
|
||||
|
||||
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
|
||||
_cleanup_llm_provider__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group.id
|
||||
)
|
||||
_cleanup_user__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group.id
|
||||
)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { LoadingAnimation } from "@/components/Loading";
|
||||
import { Button, Divider, Text } from "@tremor/react";
|
||||
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
||||
import {
|
||||
ArrayHelpers,
|
||||
ErrorMessage,
|
||||
@ -15,11 +16,16 @@ import {
|
||||
SubLabel,
|
||||
TextArrayField,
|
||||
TextFormField,
|
||||
BooleanFormField,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import { useState } from "react";
|
||||
import { Bubble } from "@/components/Bubble";
|
||||
import { GroupsIcon } from "@/components/icons/icons";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useUserGroups } from "@/lib/hooks";
|
||||
import { FullLLMProvider } from "./interfaces";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import * as Yup from "yup";
|
||||
import isEqual from "lodash/isEqual";
|
||||
|
||||
@ -44,9 +50,16 @@ export function CustomLLMProviderUpdateForm({
|
||||
}) {
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
|
||||
// EE only
|
||||
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
|
||||
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const [testError, setTestError] = useState<string>("");
|
||||
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
|
||||
// Define the initial values based on the provider's requirements
|
||||
const initialValues = {
|
||||
name: existingLlmProvider?.name ?? "",
|
||||
@ -61,6 +74,8 @@ export function CustomLLMProviderUpdateForm({
|
||||
custom_config_list: existingLlmProvider?.custom_config
|
||||
? Object.entries(existingLlmProvider.custom_config)
|
||||
: [],
|
||||
is_public: existingLlmProvider?.is_public ?? true,
|
||||
groups: existingLlmProvider?.groups ?? [],
|
||||
};
|
||||
|
||||
// Setup validation schema if required
|
||||
@ -74,6 +89,9 @@ export function CustomLLMProviderUpdateForm({
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
fast_default_model_name: Yup.string().nullable(),
|
||||
custom_config_list: Yup.array(),
|
||||
// EE Only
|
||||
is_public: Yup.boolean().required(),
|
||||
groups: Yup.array().of(Yup.number()),
|
||||
});
|
||||
|
||||
return (
|
||||
@ -97,6 +115,9 @@ export function CustomLLMProviderUpdateForm({
|
||||
return;
|
||||
}
|
||||
|
||||
// don't set groups if marked as public
|
||||
const groups = values.is_public ? [] : values.groups;
|
||||
|
||||
// test the configuration
|
||||
if (!isEqual(values, initialValues)) {
|
||||
setIsTesting(true);
|
||||
@ -188,93 +209,97 @@ export function CustomLLMProviderUpdateForm({
|
||||
setSubmitting(false);
|
||||
}}
|
||||
>
|
||||
{({ values }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="name"
|
||||
label="Display Name"
|
||||
subtext="A name which you can use to identify this provider when selecting it in the UI."
|
||||
placeholder="Display Name"
|
||||
/>
|
||||
{({ values, setFieldValue }) => {
|
||||
return (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="name"
|
||||
label="Display Name"
|
||||
subtext="A name which you can use to identify this provider when selecting it in the UI."
|
||||
placeholder="Display Name"
|
||||
/>
|
||||
|
||||
<Divider />
|
||||
<Divider />
|
||||
|
||||
<TextFormField
|
||||
name="provider"
|
||||
label="Provider Name"
|
||||
subtext={
|
||||
<TextFormField
|
||||
name="provider"
|
||||
label="Provider Name"
|
||||
subtext={
|
||||
<>
|
||||
Should be one of the providers listed at{" "}
|
||||
<a
|
||||
target="_blank"
|
||||
href="https://docs.litellm.ai/docs/providers"
|
||||
className="text-link"
|
||||
>
|
||||
https://docs.litellm.ai/docs/providers
|
||||
</a>
|
||||
.
|
||||
</>
|
||||
}
|
||||
placeholder="Name of the custom provider"
|
||||
/>
|
||||
|
||||
<Divider />
|
||||
|
||||
<SubLabel>
|
||||
Fill in the following as is needed. Refer to the LiteLLM
|
||||
documentation for the model provider name specified above in order
|
||||
to determine which fields are required.
|
||||
</SubLabel>
|
||||
|
||||
<TextFormField
|
||||
name="api_key"
|
||||
label="[Optional] API Key"
|
||||
placeholder="API Key"
|
||||
type="password"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="api_base"
|
||||
label="[Optional] API Base"
|
||||
placeholder="API Base"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="api_version"
|
||||
label="[Optional] API Version"
|
||||
placeholder="API Version"
|
||||
/>
|
||||
|
||||
<Label>[Optional] Custom Configs</Label>
|
||||
<SubLabel>
|
||||
<>
|
||||
Should be one of the providers listed at{" "}
|
||||
<a
|
||||
target="_blank"
|
||||
href="https://docs.litellm.ai/docs/providers"
|
||||
className="text-link"
|
||||
>
|
||||
https://docs.litellm.ai/docs/providers
|
||||
</a>
|
||||
.
|
||||
<div>
|
||||
Additional configurations needed by the model provider. Are
|
||||
passed to litellm via environment variables.
|
||||
</div>
|
||||
|
||||
<div className="mt-2">
|
||||
For example, when configuring the Cloudflare provider, you
|
||||
would need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
|
||||
Cloudflare account ID as the value.
|
||||
</div>
|
||||
</>
|
||||
}
|
||||
placeholder="Name of the custom provider"
|
||||
/>
|
||||
</SubLabel>
|
||||
|
||||
<Divider />
|
||||
|
||||
<SubLabel>
|
||||
Fill in the following as is needed. Refer to the LiteLLM
|
||||
documentation for the model provider name specified above in order
|
||||
to determine which fields are required.
|
||||
</SubLabel>
|
||||
|
||||
<TextFormField
|
||||
name="api_key"
|
||||
label="[Optional] API Key"
|
||||
placeholder="API Key"
|
||||
type="password"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="api_base"
|
||||
label="[Optional] API Base"
|
||||
placeholder="API Base"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="api_version"
|
||||
label="[Optional] API Version"
|
||||
placeholder="API Version"
|
||||
/>
|
||||
|
||||
<Label>[Optional] Custom Configs</Label>
|
||||
<SubLabel>
|
||||
<>
|
||||
<div>
|
||||
Additional configurations needed by the model provider. Are
|
||||
passed to litellm via environment variables.
|
||||
</div>
|
||||
|
||||
<div className="mt-2">
|
||||
For example, when configuring the Cloudflare provider, you would
|
||||
need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
|
||||
Cloudflare account ID as the value.
|
||||
</div>
|
||||
</>
|
||||
</SubLabel>
|
||||
|
||||
<FieldArray
|
||||
name="custom_config_list"
|
||||
render={(arrayHelpers: ArrayHelpers<any[]>) => (
|
||||
<div>
|
||||
{values.custom_config_list.map((_, index) => {
|
||||
return (
|
||||
<div key={index} className={index === 0 ? "mt-2" : "mt-6"}>
|
||||
<div className="flex">
|
||||
<div className="w-full mr-6 border border-border p-3 rounded">
|
||||
<div>
|
||||
<Label>Key</Label>
|
||||
<Field
|
||||
name={`custom_config_list[${index}][0]`}
|
||||
className={`
|
||||
<FieldArray
|
||||
name="custom_config_list"
|
||||
render={(arrayHelpers: ArrayHelpers<any[]>) => (
|
||||
<div>
|
||||
{values.custom_config_list.map((_, index) => {
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
className={index === 0 ? "mt-2" : "mt-6"}
|
||||
>
|
||||
<div className="flex">
|
||||
<div className="w-full mr-6 border border-border p-3 rounded">
|
||||
<div>
|
||||
<Label>Key</Label>
|
||||
<Field
|
||||
name={`custom_config_list[${index}][0]`}
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
@ -284,20 +309,20 @@ export function CustomLLMProviderUpdateForm({
|
||||
px-3
|
||||
mr-4
|
||||
`}
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`custom_config_list[${index}][0]`}
|
||||
component="div"
|
||||
className="text-error text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`custom_config_list[${index}][0]`}
|
||||
component="div"
|
||||
className="text-error text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="mt-3">
|
||||
<Label>Value</Label>
|
||||
<Field
|
||||
name={`custom_config_list[${index}][1]`}
|
||||
className={`
|
||||
<div className="mt-3">
|
||||
<Label>Value</Label>
|
||||
<Field
|
||||
name={`custom_config_list[${index}][1]`}
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
@ -307,121 +332,190 @@ export function CustomLLMProviderUpdateForm({
|
||||
px-3
|
||||
mr-4
|
||||
`}
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`custom_config_list[${index}][1]`}
|
||||
component="div"
|
||||
className="text-error text-sm mt-1"
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`custom_config_list[${index}][1]`}
|
||||
component="div"
|
||||
className="text-error text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="my-auto">
|
||||
<FiX
|
||||
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
|
||||
onClick={() => arrayHelpers.remove(index)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="my-auto">
|
||||
<FiX
|
||||
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
|
||||
onClick={() => arrayHelpers.remove(index)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
);
|
||||
})}
|
||||
|
||||
<Button
|
||||
onClick={() => {
|
||||
arrayHelpers.push(["", ""]);
|
||||
}}
|
||||
className="mt-3"
|
||||
color="green"
|
||||
size="xs"
|
||||
type="button"
|
||||
icon={FiPlus}
|
||||
>
|
||||
Add New
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
<Button
|
||||
onClick={() => {
|
||||
arrayHelpers.push(["", ""]);
|
||||
}}
|
||||
className="mt-3"
|
||||
color="green"
|
||||
size="xs"
|
||||
type="button"
|
||||
icon={FiPlus}
|
||||
>
|
||||
Add New
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
|
||||
<Divider />
|
||||
<Divider />
|
||||
|
||||
<TextArrayField
|
||||
name="model_names"
|
||||
label="Model Names"
|
||||
values={values}
|
||||
subtext={`List the individual models that you want to make
|
||||
<TextArrayField
|
||||
name="model_names"
|
||||
label="Model Names"
|
||||
values={values}
|
||||
subtext={`List the individual models that you want to make
|
||||
available as a part of this provider. At least one must be specified.
|
||||
As an example, for OpenAI one model might be "gpt-4".`}
|
||||
/>
|
||||
/>
|
||||
|
||||
<Divider />
|
||||
<Divider />
|
||||
|
||||
<TextFormField
|
||||
name="default_model_name"
|
||||
subtext={`
|
||||
<TextFormField
|
||||
name="default_model_name"
|
||||
subtext={`
|
||||
The model to use by default for this provider unless
|
||||
otherwise specified. Must be one of the models listed
|
||||
above.`}
|
||||
label="Default Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
label="Default Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
<TextFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
for this provider. If not set, will use
|
||||
the Default Model configured above.`}
|
||||
label="[Optional] Fast Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
label="[Optional] Fast Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
|
||||
<Divider />
|
||||
<Divider />
|
||||
|
||||
<div>
|
||||
{/* NOTE: this is above the test button to make sure it's visible */}
|
||||
{testError && <Text className="text-error mt-2">{testError}</Text>}
|
||||
<AdvancedOptionsToggle
|
||||
showAdvancedOptions={showAdvancedOptions}
|
||||
setShowAdvancedOptions={setShowAdvancedOptions}
|
||||
/>
|
||||
|
||||
<div className="flex w-full mt-4">
|
||||
<Button type="submit" size="xs">
|
||||
{isTesting ? (
|
||||
<LoadingAnimation text="Testing" />
|
||||
) : existingLlmProvider ? (
|
||||
"Update"
|
||||
) : (
|
||||
"Enable"
|
||||
{showAdvancedOptions && (
|
||||
<>
|
||||
{isPaidEnterpriseFeaturesEnabled && userGroups && (
|
||||
<>
|
||||
<BooleanFormField
|
||||
small
|
||||
noPadding
|
||||
alignTop
|
||||
name="is_public"
|
||||
label="Is Public?"
|
||||
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
|
||||
/>
|
||||
|
||||
{userGroups &&
|
||||
userGroups.length > 0 &&
|
||||
!values.is_public && (
|
||||
<div>
|
||||
<Text>
|
||||
Select which User Groups should have access to this
|
||||
LLM Provider.
|
||||
</Text>
|
||||
<div className="flex flex-wrap gap-2 mt-2">
|
||||
{userGroups.map((userGroup) => {
|
||||
const isSelected = values.groups.includes(
|
||||
userGroup.id
|
||||
);
|
||||
return (
|
||||
<Bubble
|
||||
key={userGroup.id}
|
||||
isSelected={isSelected}
|
||||
onClick={() => {
|
||||
if (isSelected) {
|
||||
setFieldValue(
|
||||
"groups",
|
||||
values.groups.filter(
|
||||
(id) => id !== userGroup.id
|
||||
)
|
||||
);
|
||||
} else {
|
||||
setFieldValue("groups", [
|
||||
...values.groups,
|
||||
userGroup.id,
|
||||
]);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="flex">
|
||||
<GroupsIcon />
|
||||
<div className="ml-1">{userGroup.name}</div>
|
||||
</div>
|
||||
</Bubble>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
{existingLlmProvider && (
|
||||
<Button
|
||||
type="button"
|
||||
color="red"
|
||||
className="ml-3"
|
||||
size="xs"
|
||||
icon={FiTrash}
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
alert(`Failed to delete provider: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
</>
|
||||
)}
|
||||
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
onClose();
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
<div>
|
||||
{/* NOTE: this is above the test button to make sure it's visible */}
|
||||
{testError && (
|
||||
<Text className="text-error mt-2">{testError}</Text>
|
||||
)}
|
||||
|
||||
<div className="flex w-full mt-4">
|
||||
<Button type="submit" size="xs">
|
||||
{isTesting ? (
|
||||
<LoadingAnimation text="Testing" />
|
||||
) : existingLlmProvider ? (
|
||||
"Update"
|
||||
) : (
|
||||
"Enable"
|
||||
)}
|
||||
</Button>
|
||||
{existingLlmProvider && (
|
||||
<Button
|
||||
type="button"
|
||||
color="red"
|
||||
className="ml-3"
|
||||
size="xs"
|
||||
icon={FiTrash}
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
alert(`Failed to delete provider: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
onClose();
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Form>
|
||||
);
|
||||
}}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { LoadingAnimation } from "@/components/Loading";
|
||||
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
||||
import { Button, Divider, Text } from "@tremor/react";
|
||||
import { Form, Formik } from "formik";
|
||||
import { FiTrash } from "react-icons/fi";
|
||||
@ -6,11 +7,16 @@ import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
|
||||
import {
|
||||
SelectorFormField,
|
||||
TextFormField,
|
||||
BooleanFormField,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import { useState } from "react";
|
||||
import { Bubble } from "@/components/Bubble";
|
||||
import { GroupsIcon } from "@/components/icons/icons";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useUserGroups } from "@/lib/hooks";
|
||||
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import * as Yup from "yup";
|
||||
import isEqual from "lodash/isEqual";
|
||||
|
||||
@ -29,9 +35,16 @@ export function LLMProviderUpdateForm({
|
||||
}) {
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
|
||||
// EE only
|
||||
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
|
||||
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const [testError, setTestError] = useState<string>("");
|
||||
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
|
||||
// Define the initial values based on the provider's requirements
|
||||
const initialValues = {
|
||||
name: existingLlmProvider?.name ?? "",
|
||||
@ -54,6 +67,8 @@ export function LLMProviderUpdateForm({
|
||||
},
|
||||
{} as { [key: string]: string }
|
||||
),
|
||||
is_public: existingLlmProvider?.is_public ?? true,
|
||||
groups: existingLlmProvider?.groups ?? [],
|
||||
};
|
||||
|
||||
const [validatedConfig, setValidatedConfig] = useState(
|
||||
@ -91,6 +106,9 @@ export function LLMProviderUpdateForm({
|
||||
: {}),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
fast_default_model_name: Yup.string().nullable(),
|
||||
// EE Only
|
||||
is_public: Yup.boolean().required(),
|
||||
groups: Yup.array().of(Yup.number()),
|
||||
});
|
||||
|
||||
return (
|
||||
@ -193,7 +211,7 @@ export function LLMProviderUpdateForm({
|
||||
setSubmitting(false);
|
||||
}}
|
||||
>
|
||||
{({ values }) => (
|
||||
{({ values, setFieldValue }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="name"
|
||||
@ -293,6 +311,69 @@ export function LLMProviderUpdateForm({
|
||||
|
||||
<Divider />
|
||||
|
||||
<AdvancedOptionsToggle
|
||||
showAdvancedOptions={showAdvancedOptions}
|
||||
setShowAdvancedOptions={setShowAdvancedOptions}
|
||||
/>
|
||||
|
||||
{showAdvancedOptions && (
|
||||
<>
|
||||
{isPaidEnterpriseFeaturesEnabled && userGroups && (
|
||||
<>
|
||||
<BooleanFormField
|
||||
small
|
||||
noPadding
|
||||
alignTop
|
||||
name="is_public"
|
||||
label="Is Public?"
|
||||
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
|
||||
/>
|
||||
|
||||
{userGroups && userGroups.length > 0 && !values.is_public && (
|
||||
<div>
|
||||
<Text>
|
||||
Select which User Groups should have access to this LLM
|
||||
Provider.
|
||||
</Text>
|
||||
<div className="flex flex-wrap gap-2 mt-2">
|
||||
{userGroups.map((userGroup) => {
|
||||
const isSelected = values.groups.includes(
|
||||
userGroup.id
|
||||
);
|
||||
return (
|
||||
<Bubble
|
||||
key={userGroup.id}
|
||||
isSelected={isSelected}
|
||||
onClick={() => {
|
||||
if (isSelected) {
|
||||
setFieldValue(
|
||||
"groups",
|
||||
values.groups.filter(
|
||||
(id) => id !== userGroup.id
|
||||
)
|
||||
);
|
||||
} else {
|
||||
setFieldValue("groups", [
|
||||
...values.groups,
|
||||
userGroup.id,
|
||||
]);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="flex">
|
||||
<GroupsIcon />
|
||||
<div className="ml-1">{userGroup.name}</div>
|
||||
</div>
|
||||
</Bubble>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<div>
|
||||
{/* NOTE: this is above the test button to make sure it's visible */}
|
||||
{testError && <Text className="text-error mt-2">{testError}</Text>}
|
||||
|
@ -17,6 +17,8 @@ export interface WellKnownLLMProviderDescriptor {
|
||||
llm_names: string[];
|
||||
default_model: string | null;
|
||||
default_fast_model: string | null;
|
||||
is_public: boolean;
|
||||
groups: number[];
|
||||
}
|
||||
|
||||
export interface LLMProvider {
|
||||
@ -28,6 +30,8 @@ export interface LLMProvider {
|
||||
custom_config: { [key: string]: string } | null;
|
||||
default_model_name: string;
|
||||
fast_default_model_name: string | null;
|
||||
is_public: boolean;
|
||||
groups: number[];
|
||||
}
|
||||
|
||||
export interface FullLLMProvider extends LLMProvider {
|
||||
@ -44,4 +48,6 @@ export interface LLMProviderDescriptor {
|
||||
default_model_name: string;
|
||||
fast_default_model_name: string | null;
|
||||
is_default_provider: boolean | null;
|
||||
is_public: boolean;
|
||||
groups: number[];
|
||||
}
|
||||
|
26
web/src/components/AdvancedOptionsToggle.tsx
Normal file
26
web/src/components/AdvancedOptionsToggle.tsx
Normal file
@ -0,0 +1,26 @@
|
||||
import React from "react";
|
||||
import { Button } from "@tremor/react";
|
||||
import { FiChevronDown, FiChevronRight } from "react-icons/fi";
|
||||
|
||||
interface AdvancedOptionsToggleProps {
|
||||
showAdvancedOptions: boolean;
|
||||
setShowAdvancedOptions: (show: boolean) => void;
|
||||
}
|
||||
|
||||
export function AdvancedOptionsToggle({
|
||||
showAdvancedOptions,
|
||||
setShowAdvancedOptions,
|
||||
}: AdvancedOptionsToggleProps) {
|
||||
return (
|
||||
<Button
|
||||
type="button"
|
||||
variant="light"
|
||||
size="xs"
|
||||
icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
|
||||
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
|
||||
className="mb-4 text-xs text-text-500 hover:text-text-400"
|
||||
>
|
||||
Advanced Options
|
||||
</Button>
|
||||
);
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user