mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-27 12:29:41 +02:00
Added ability to control LLM access based on group (#1870)
* Added ability to control LLM access based on group * completed relationship deletion * cleaned up function * added comments * fixed frontend strings * mypy fixes * added case handling for deletion of user groups * hidden advanced options now * removed unnecessary code
This commit is contained in:
@@ -0,0 +1,41 @@
|
|||||||
|
"""add_llm_group_permissions_control
|
||||||
|
|
||||||
|
Revision ID: 795b20b85b4b
|
||||||
|
Revises: 05c07bf07c00
|
||||||
|
Create Date: 2024-07-19 11:54:35.701558
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
revision = "795b20b85b4b"
|
||||||
|
down_revision = "05c07bf07c00"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"llm_provider__user_group",
|
||||||
|
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_group_id", sa.Integer(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["llm_provider_id"],
|
||||||
|
["llm_provider.id"],
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_group_id"],
|
||||||
|
["user_group.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"llm_provider",
|
||||||
|
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("llm_provider__user_group")
|
||||||
|
op.drop_column("llm_provider", "is_public")
|
@@ -1,15 +1,42 @@
|
|||||||
from sqlalchemy import delete
|
from sqlalchemy import delete
|
||||||
|
from sqlalchemy import or_
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||||
|
from danswer.db.models import LLMProvider__UserGroup
|
||||||
|
from danswer.db.models import User
|
||||||
|
from danswer.db.models import User__UserGroup
|
||||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||||
from danswer.server.manage.llm.models import FullLLMProvider
|
from danswer.server.manage.llm.models import FullLLMProvider
|
||||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||||
|
|
||||||
|
|
||||||
|
def update_group_llm_provider_relationships(
|
||||||
|
llm_provider_id: int,
|
||||||
|
group_ids: list[int] | None,
|
||||||
|
db_session: Session,
|
||||||
|
) -> None:
|
||||||
|
# Delete existing relationships
|
||||||
|
db_session.query(LLMProvider__UserGroup).filter(
|
||||||
|
LLMProvider__UserGroup.llm_provider_id == llm_provider_id
|
||||||
|
).delete(synchronize_session="fetch")
|
||||||
|
|
||||||
|
# Add new relationships from given group_ids
|
||||||
|
if group_ids:
|
||||||
|
new_relationships = [
|
||||||
|
LLMProvider__UserGroup(
|
||||||
|
llm_provider_id=llm_provider_id,
|
||||||
|
user_group_id=group_id,
|
||||||
|
)
|
||||||
|
for group_id in group_ids
|
||||||
|
]
|
||||||
|
db_session.add_all(new_relationships)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
def upsert_cloud_embedding_provider(
|
def upsert_cloud_embedding_provider(
|
||||||
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
|
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
|
||||||
) -> CloudEmbeddingProvider:
|
) -> CloudEmbeddingProvider:
|
||||||
@@ -36,36 +63,33 @@ def upsert_llm_provider(
|
|||||||
existing_llm_provider = db_session.scalar(
|
existing_llm_provider = db_session.scalar(
|
||||||
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
||||||
)
|
)
|
||||||
if existing_llm_provider:
|
|
||||||
existing_llm_provider.provider = llm_provider.provider
|
|
||||||
existing_llm_provider.api_key = llm_provider.api_key
|
|
||||||
existing_llm_provider.api_base = llm_provider.api_base
|
|
||||||
existing_llm_provider.api_version = llm_provider.api_version
|
|
||||||
existing_llm_provider.custom_config = llm_provider.custom_config
|
|
||||||
existing_llm_provider.default_model_name = llm_provider.default_model_name
|
|
||||||
existing_llm_provider.fast_default_model_name = (
|
|
||||||
llm_provider.fast_default_model_name
|
|
||||||
)
|
|
||||||
existing_llm_provider.model_names = llm_provider.model_names
|
|
||||||
db_session.commit()
|
|
||||||
return FullLLMProvider.from_model(existing_llm_provider)
|
|
||||||
# if it does not exist, create a new entry
|
|
||||||
llm_provider_model = LLMProviderModel(
|
|
||||||
name=llm_provider.name,
|
|
||||||
provider=llm_provider.provider,
|
|
||||||
api_key=llm_provider.api_key,
|
|
||||||
api_base=llm_provider.api_base,
|
|
||||||
api_version=llm_provider.api_version,
|
|
||||||
custom_config=llm_provider.custom_config,
|
|
||||||
default_model_name=llm_provider.default_model_name,
|
|
||||||
fast_default_model_name=llm_provider.fast_default_model_name,
|
|
||||||
model_names=llm_provider.model_names,
|
|
||||||
is_default_provider=None,
|
|
||||||
)
|
|
||||||
db_session.add(llm_provider_model)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
return FullLLMProvider.from_model(llm_provider_model)
|
if not existing_llm_provider:
|
||||||
|
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
|
||||||
|
db_session.add(existing_llm_provider)
|
||||||
|
|
||||||
|
existing_llm_provider.provider = llm_provider.provider
|
||||||
|
existing_llm_provider.api_key = llm_provider.api_key
|
||||||
|
existing_llm_provider.api_base = llm_provider.api_base
|
||||||
|
existing_llm_provider.api_version = llm_provider.api_version
|
||||||
|
existing_llm_provider.custom_config = llm_provider.custom_config
|
||||||
|
existing_llm_provider.default_model_name = llm_provider.default_model_name
|
||||||
|
existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
|
||||||
|
existing_llm_provider.model_names = llm_provider.model_names
|
||||||
|
existing_llm_provider.is_public = llm_provider.is_public
|
||||||
|
|
||||||
|
if not existing_llm_provider.id:
|
||||||
|
# If its not already in the db, we need to generate an ID by flushing
|
||||||
|
db_session.flush()
|
||||||
|
|
||||||
|
# Make sure the relationship table stays up to date
|
||||||
|
update_group_llm_provider_relationships(
|
||||||
|
llm_provider_id=existing_llm_provider.id,
|
||||||
|
group_ids=llm_provider.groups,
|
||||||
|
db_session=db_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
return FullLLMProvider.from_model(existing_llm_provider)
|
||||||
|
|
||||||
|
|
||||||
def fetch_existing_embedding_providers(
|
def fetch_existing_embedding_providers(
|
||||||
@@ -74,8 +98,29 @@ def fetch_existing_embedding_providers(
|
|||||||
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
||||||
|
|
||||||
|
|
||||||
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
|
def fetch_existing_llm_providers(
|
||||||
return list(db_session.scalars(select(LLMProviderModel)).all())
|
db_session: Session,
|
||||||
|
user: User | None = None,
|
||||||
|
) -> list[LLMProviderModel]:
|
||||||
|
if not user:
|
||||||
|
return list(db_session.scalars(select(LLMProviderModel)).all())
|
||||||
|
stmt = select(LLMProviderModel).distinct()
|
||||||
|
user_groups_subquery = (
|
||||||
|
select(User__UserGroup.user_group_id)
|
||||||
|
.where(User__UserGroup.user_id == user.id)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
access_conditions = or_(
|
||||||
|
LLMProviderModel.is_public,
|
||||||
|
LLMProviderModel.id.in_( # User is part of a group that has access
|
||||||
|
select(LLMProvider__UserGroup.llm_provider_id).where(
|
||||||
|
LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
stmt = stmt.where(access_conditions)
|
||||||
|
|
||||||
|
return list(db_session.scalars(stmt).all())
|
||||||
|
|
||||||
|
|
||||||
def fetch_embedding_provider(
|
def fetch_embedding_provider(
|
||||||
@@ -119,6 +164,13 @@ def remove_embedding_provider(
|
|||||||
|
|
||||||
|
|
||||||
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||||
|
# Remove LLMProvider's dependent relationships
|
||||||
|
db_session.execute(
|
||||||
|
delete(LLMProvider__UserGroup).where(
|
||||||
|
LLMProvider__UserGroup.llm_provider_id == provider_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Remove LLMProvider
|
||||||
db_session.execute(
|
db_session.execute(
|
||||||
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||||
)
|
)
|
||||||
|
@@ -932,6 +932,13 @@ class LLMProvider(Base):
|
|||||||
|
|
||||||
# should only be set for a single provider
|
# should only be set for a single provider
|
||||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||||
|
# EE only
|
||||||
|
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
|
groups: Mapped[list["UserGroup"]] = relationship(
|
||||||
|
"UserGroup",
|
||||||
|
secondary="llm_provider__user_group",
|
||||||
|
viewonly=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CloudEmbeddingProvider(Base):
|
class CloudEmbeddingProvider(Base):
|
||||||
@@ -1109,7 +1116,6 @@ class Persona(Base):
|
|||||||
# where lower value IDs (e.g. created earlier) are displayed first
|
# where lower value IDs (e.g. created earlier) are displayed first
|
||||||
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
|
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
|
||||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
|
||||||
|
|
||||||
# These are only defaults, users can select from all if desired
|
# These are only defaults, users can select from all if desired
|
||||||
prompts: Mapped[list[Prompt]] = relationship(
|
prompts: Mapped[list[Prompt]] = relationship(
|
||||||
@@ -1137,6 +1143,7 @@ class Persona(Base):
|
|||||||
viewonly=True,
|
viewonly=True,
|
||||||
)
|
)
|
||||||
# EE only
|
# EE only
|
||||||
|
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
groups: Mapped[list["UserGroup"]] = relationship(
|
groups: Mapped[list["UserGroup"]] = relationship(
|
||||||
"UserGroup",
|
"UserGroup",
|
||||||
secondary="persona__user_group",
|
secondary="persona__user_group",
|
||||||
@@ -1360,6 +1367,17 @@ class Persona__UserGroup(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMProvider__UserGroup(Base):
|
||||||
|
__tablename__ = "llm_provider__user_group"
|
||||||
|
|
||||||
|
llm_provider_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("llm_provider.id"), primary_key=True
|
||||||
|
)
|
||||||
|
user_group_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("user_group.id"), primary_key=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DocumentSet__UserGroup(Base):
|
class DocumentSet__UserGroup(Base):
|
||||||
__tablename__ = "document_set__user_group"
|
__tablename__ = "document_set__user_group"
|
||||||
|
|
||||||
|
@@ -103,6 +103,7 @@ def port_api_key_to_postgres() -> None:
|
|||||||
default_model_name=default_model_name,
|
default_model_name=default_model_name,
|
||||||
fast_default_model_name=default_fast_model_name,
|
fast_default_model_name=default_fast_model_name,
|
||||||
model_names=None,
|
model_names=None,
|
||||||
|
is_public=True,
|
||||||
)
|
)
|
||||||
llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
|
llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
|
||||||
update_default_provider(db_session, llm_provider.id)
|
update_default_provider(db_session, llm_provider.id)
|
||||||
|
@@ -70,6 +70,7 @@ def load_llm_providers(db_session: Session) -> None:
|
|||||||
FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
|
FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
|
||||||
),
|
),
|
||||||
model_names=model_names,
|
model_names=model_names,
|
||||||
|
is_public=True,
|
||||||
)
|
)
|
||||||
llm_provider = upsert_llm_provider(db_session, llm_provider_request)
|
llm_provider = upsert_llm_provider(db_session, llm_provider_request)
|
||||||
update_default_provider(db_session, llm_provider.id)
|
update_default_provider(db_session, llm_provider.id)
|
||||||
|
@@ -147,10 +147,10 @@ def set_provider_as_default(
|
|||||||
|
|
||||||
@basic_router.get("/provider")
|
@basic_router.get("/provider")
|
||||||
def list_llm_provider_basics(
|
def list_llm_provider_basics(
|
||||||
_: User | None = Depends(current_user),
|
user: User | None = Depends(current_user),
|
||||||
db_session: Session = Depends(get_session),
|
db_session: Session = Depends(get_session),
|
||||||
) -> list[LLMProviderDescriptor]:
|
) -> list[LLMProviderDescriptor]:
|
||||||
return [
|
return [
|
||||||
LLMProviderDescriptor.from_model(llm_provider_model)
|
LLMProviderDescriptor.from_model(llm_provider_model)
|
||||||
for llm_provider_model in fetch_existing_llm_providers(db_session)
|
for llm_provider_model in fetch_existing_llm_providers(db_session, user)
|
||||||
]
|
]
|
||||||
|
@@ -60,6 +60,8 @@ class LLMProvider(BaseModel):
|
|||||||
custom_config: dict[str, str] | None
|
custom_config: dict[str, str] | None
|
||||||
default_model_name: str
|
default_model_name: str
|
||||||
fast_default_model_name: str | None
|
fast_default_model_name: str | None
|
||||||
|
is_public: bool
|
||||||
|
groups: list[int] | None = None
|
||||||
|
|
||||||
|
|
||||||
class LLMProviderUpsertRequest(LLMProvider):
|
class LLMProviderUpsertRequest(LLMProvider):
|
||||||
@@ -91,4 +93,6 @@ class FullLLMProvider(LLMProvider):
|
|||||||
or fetch_models_for_provider(llm_provider_model.provider)
|
or fetch_models_for_provider(llm_provider_model.provider)
|
||||||
or [llm_provider_model.default_model_name]
|
or [llm_provider_model.default_model_name]
|
||||||
),
|
),
|
||||||
|
is_public=llm_provider_model.is_public,
|
||||||
|
groups=[group.id for group in llm_provider_model.groups],
|
||||||
)
|
)
|
||||||
|
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
|||||||
from danswer.db.models import ConnectorCredentialPair
|
from danswer.db.models import ConnectorCredentialPair
|
||||||
from danswer.db.models import Document
|
from danswer.db.models import Document
|
||||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||||
|
from danswer.db.models import LLMProvider__UserGroup
|
||||||
from danswer.db.models import TokenRateLimit__UserGroup
|
from danswer.db.models import TokenRateLimit__UserGroup
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from danswer.db.models import User__UserGroup
|
from danswer.db.models import User__UserGroup
|
||||||
@@ -194,6 +195,15 @@ def _cleanup_user__user_group_relationships__no_commit(
|
|||||||
db_session.delete(user__user_group_relationship)
|
db_session.delete(user__user_group_relationship)
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_llm_provider__user_group_relationships__no_commit(
|
||||||
|
db_session: Session, user_group_id: int
|
||||||
|
) -> None:
|
||||||
|
"""NOTE: does not commit the transaction."""
|
||||||
|
db_session.query(LLMProvider__UserGroup).filter(
|
||||||
|
LLMProvider__UserGroup.user_group_id == user_group_id
|
||||||
|
).delete(synchronize_session=False)
|
||||||
|
|
||||||
|
|
||||||
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
|
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||||
db_session: Session, user_group_id: int
|
db_session: Session, user_group_id: int
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -316,6 +326,9 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
|
|||||||
|
|
||||||
|
|
||||||
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
|
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
|
||||||
|
_cleanup_llm_provider__user_group_relationships__no_commit(
|
||||||
|
db_session=db_session, user_group_id=user_group.id
|
||||||
|
)
|
||||||
_cleanup_user__user_group_relationships__no_commit(
|
_cleanup_user__user_group_relationships__no_commit(
|
||||||
db_session=db_session, user_group_id=user_group.id
|
db_session=db_session, user_group_id=user_group.id
|
||||||
)
|
)
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
import { LoadingAnimation } from "@/components/Loading";
|
import { LoadingAnimation } from "@/components/Loading";
|
||||||
import { Button, Divider, Text } from "@tremor/react";
|
import { Button, Divider, Text } from "@tremor/react";
|
||||||
|
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
||||||
import {
|
import {
|
||||||
ArrayHelpers,
|
ArrayHelpers,
|
||||||
ErrorMessage,
|
ErrorMessage,
|
||||||
@@ -15,11 +16,16 @@ import {
|
|||||||
SubLabel,
|
SubLabel,
|
||||||
TextArrayField,
|
TextArrayField,
|
||||||
TextFormField,
|
TextFormField,
|
||||||
|
BooleanFormField,
|
||||||
} from "@/components/admin/connectors/Field";
|
} from "@/components/admin/connectors/Field";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
|
import { Bubble } from "@/components/Bubble";
|
||||||
|
import { GroupsIcon } from "@/components/icons/icons";
|
||||||
import { useSWRConfig } from "swr";
|
import { useSWRConfig } from "swr";
|
||||||
|
import { useUserGroups } from "@/lib/hooks";
|
||||||
import { FullLLMProvider } from "./interfaces";
|
import { FullLLMProvider } from "./interfaces";
|
||||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||||
|
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||||
import * as Yup from "yup";
|
import * as Yup from "yup";
|
||||||
import isEqual from "lodash/isEqual";
|
import isEqual from "lodash/isEqual";
|
||||||
|
|
||||||
@@ -44,9 +50,16 @@ export function CustomLLMProviderUpdateForm({
|
|||||||
}) {
|
}) {
|
||||||
const { mutate } = useSWRConfig();
|
const { mutate } = useSWRConfig();
|
||||||
|
|
||||||
|
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||||
|
|
||||||
|
// EE only
|
||||||
|
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
|
||||||
|
|
||||||
const [isTesting, setIsTesting] = useState(false);
|
const [isTesting, setIsTesting] = useState(false);
|
||||||
const [testError, setTestError] = useState<string>("");
|
const [testError, setTestError] = useState<string>("");
|
||||||
|
|
||||||
|
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||||
|
|
||||||
// Define the initial values based on the provider's requirements
|
// Define the initial values based on the provider's requirements
|
||||||
const initialValues = {
|
const initialValues = {
|
||||||
name: existingLlmProvider?.name ?? "",
|
name: existingLlmProvider?.name ?? "",
|
||||||
@@ -61,6 +74,8 @@ export function CustomLLMProviderUpdateForm({
|
|||||||
custom_config_list: existingLlmProvider?.custom_config
|
custom_config_list: existingLlmProvider?.custom_config
|
||||||
? Object.entries(existingLlmProvider.custom_config)
|
? Object.entries(existingLlmProvider.custom_config)
|
||||||
: [],
|
: [],
|
||||||
|
is_public: existingLlmProvider?.is_public ?? true,
|
||||||
|
groups: existingLlmProvider?.groups ?? [],
|
||||||
};
|
};
|
||||||
|
|
||||||
// Setup validation schema if required
|
// Setup validation schema if required
|
||||||
@@ -74,6 +89,9 @@ export function CustomLLMProviderUpdateForm({
|
|||||||
default_model_name: Yup.string().required("Model name is required"),
|
default_model_name: Yup.string().required("Model name is required"),
|
||||||
fast_default_model_name: Yup.string().nullable(),
|
fast_default_model_name: Yup.string().nullable(),
|
||||||
custom_config_list: Yup.array(),
|
custom_config_list: Yup.array(),
|
||||||
|
// EE Only
|
||||||
|
is_public: Yup.boolean().required(),
|
||||||
|
groups: Yup.array().of(Yup.number()),
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -97,6 +115,9 @@ export function CustomLLMProviderUpdateForm({
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// don't set groups if marked as public
|
||||||
|
const groups = values.is_public ? [] : values.groups;
|
||||||
|
|
||||||
// test the configuration
|
// test the configuration
|
||||||
if (!isEqual(values, initialValues)) {
|
if (!isEqual(values, initialValues)) {
|
||||||
setIsTesting(true);
|
setIsTesting(true);
|
||||||
@@ -188,93 +209,97 @@ export function CustomLLMProviderUpdateForm({
|
|||||||
setSubmitting(false);
|
setSubmitting(false);
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{({ values }) => (
|
{({ values, setFieldValue }) => {
|
||||||
<Form>
|
return (
|
||||||
<TextFormField
|
<Form>
|
||||||
name="name"
|
<TextFormField
|
||||||
label="Display Name"
|
name="name"
|
||||||
subtext="A name which you can use to identify this provider when selecting it in the UI."
|
label="Display Name"
|
||||||
placeholder="Display Name"
|
subtext="A name which you can use to identify this provider when selecting it in the UI."
|
||||||
/>
|
placeholder="Display Name"
|
||||||
|
/>
|
||||||
|
|
||||||
<Divider />
|
<Divider />
|
||||||
|
|
||||||
<TextFormField
|
<TextFormField
|
||||||
name="provider"
|
name="provider"
|
||||||
label="Provider Name"
|
label="Provider Name"
|
||||||
subtext={
|
subtext={
|
||||||
|
<>
|
||||||
|
Should be one of the providers listed at{" "}
|
||||||
|
<a
|
||||||
|
target="_blank"
|
||||||
|
href="https://docs.litellm.ai/docs/providers"
|
||||||
|
className="text-link"
|
||||||
|
>
|
||||||
|
https://docs.litellm.ai/docs/providers
|
||||||
|
</a>
|
||||||
|
.
|
||||||
|
</>
|
||||||
|
}
|
||||||
|
placeholder="Name of the custom provider"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Divider />
|
||||||
|
|
||||||
|
<SubLabel>
|
||||||
|
Fill in the following as is needed. Refer to the LiteLLM
|
||||||
|
documentation for the model provider name specified above in order
|
||||||
|
to determine which fields are required.
|
||||||
|
</SubLabel>
|
||||||
|
|
||||||
|
<TextFormField
|
||||||
|
name="api_key"
|
||||||
|
label="[Optional] API Key"
|
||||||
|
placeholder="API Key"
|
||||||
|
type="password"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<TextFormField
|
||||||
|
name="api_base"
|
||||||
|
label="[Optional] API Base"
|
||||||
|
placeholder="API Base"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<TextFormField
|
||||||
|
name="api_version"
|
||||||
|
label="[Optional] API Version"
|
||||||
|
placeholder="API Version"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Label>[Optional] Custom Configs</Label>
|
||||||
|
<SubLabel>
|
||||||
<>
|
<>
|
||||||
Should be one of the providers listed at{" "}
|
<div>
|
||||||
<a
|
Additional configurations needed by the model provider. Are
|
||||||
target="_blank"
|
passed to litellm via environment variables.
|
||||||
href="https://docs.litellm.ai/docs/providers"
|
</div>
|
||||||
className="text-link"
|
|
||||||
>
|
<div className="mt-2">
|
||||||
https://docs.litellm.ai/docs/providers
|
For example, when configuring the Cloudflare provider, you
|
||||||
</a>
|
would need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
|
||||||
.
|
Cloudflare account ID as the value.
|
||||||
|
</div>
|
||||||
</>
|
</>
|
||||||
}
|
</SubLabel>
|
||||||
placeholder="Name of the custom provider"
|
|
||||||
/>
|
|
||||||
|
|
||||||
<Divider />
|
<FieldArray
|
||||||
|
name="custom_config_list"
|
||||||
<SubLabel>
|
render={(arrayHelpers: ArrayHelpers<any[]>) => (
|
||||||
Fill in the following as is needed. Refer to the LiteLLM
|
<div>
|
||||||
documentation for the model provider name specified above in order
|
{values.custom_config_list.map((_, index) => {
|
||||||
to determine which fields are required.
|
return (
|
||||||
</SubLabel>
|
<div
|
||||||
|
key={index}
|
||||||
<TextFormField
|
className={index === 0 ? "mt-2" : "mt-6"}
|
||||||
name="api_key"
|
>
|
||||||
label="[Optional] API Key"
|
<div className="flex">
|
||||||
placeholder="API Key"
|
<div className="w-full mr-6 border border-border p-3 rounded">
|
||||||
type="password"
|
<div>
|
||||||
/>
|
<Label>Key</Label>
|
||||||
|
<Field
|
||||||
<TextFormField
|
name={`custom_config_list[${index}][0]`}
|
||||||
name="api_base"
|
className={`
|
||||||
label="[Optional] API Base"
|
|
||||||
placeholder="API Base"
|
|
||||||
/>
|
|
||||||
|
|
||||||
<TextFormField
|
|
||||||
name="api_version"
|
|
||||||
label="[Optional] API Version"
|
|
||||||
placeholder="API Version"
|
|
||||||
/>
|
|
||||||
|
|
||||||
<Label>[Optional] Custom Configs</Label>
|
|
||||||
<SubLabel>
|
|
||||||
<>
|
|
||||||
<div>
|
|
||||||
Additional configurations needed by the model provider. Are
|
|
||||||
passed to litellm via environment variables.
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="mt-2">
|
|
||||||
For example, when configuring the Cloudflare provider, you would
|
|
||||||
need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
|
|
||||||
Cloudflare account ID as the value.
|
|
||||||
</div>
|
|
||||||
</>
|
|
||||||
</SubLabel>
|
|
||||||
|
|
||||||
<FieldArray
|
|
||||||
name="custom_config_list"
|
|
||||||
render={(arrayHelpers: ArrayHelpers<any[]>) => (
|
|
||||||
<div>
|
|
||||||
{values.custom_config_list.map((_, index) => {
|
|
||||||
return (
|
|
||||||
<div key={index} className={index === 0 ? "mt-2" : "mt-6"}>
|
|
||||||
<div className="flex">
|
|
||||||
<div className="w-full mr-6 border border-border p-3 rounded">
|
|
||||||
<div>
|
|
||||||
<Label>Key</Label>
|
|
||||||
<Field
|
|
||||||
name={`custom_config_list[${index}][0]`}
|
|
||||||
className={`
|
|
||||||
border
|
border
|
||||||
border-border
|
border-border
|
||||||
bg-background
|
bg-background
|
||||||
@@ -284,20 +309,20 @@ export function CustomLLMProviderUpdateForm({
|
|||||||
px-3
|
px-3
|
||||||
mr-4
|
mr-4
|
||||||
`}
|
`}
|
||||||
autoComplete="off"
|
autoComplete="off"
|
||||||
/>
|
/>
|
||||||
<ErrorMessage
|
<ErrorMessage
|
||||||
name={`custom_config_list[${index}][0]`}
|
name={`custom_config_list[${index}][0]`}
|
||||||
component="div"
|
component="div"
|
||||||
className="text-error text-sm mt-1"
|
className="text-error text-sm mt-1"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="mt-3">
|
<div className="mt-3">
|
||||||
<Label>Value</Label>
|
<Label>Value</Label>
|
||||||
<Field
|
<Field
|
||||||
name={`custom_config_list[${index}][1]`}
|
name={`custom_config_list[${index}][1]`}
|
||||||
className={`
|
className={`
|
||||||
border
|
border
|
||||||
border-border
|
border-border
|
||||||
bg-background
|
bg-background
|
||||||
@@ -307,121 +332,190 @@ export function CustomLLMProviderUpdateForm({
|
|||||||
px-3
|
px-3
|
||||||
mr-4
|
mr-4
|
||||||
`}
|
`}
|
||||||
autoComplete="off"
|
autoComplete="off"
|
||||||
/>
|
/>
|
||||||
<ErrorMessage
|
<ErrorMessage
|
||||||
name={`custom_config_list[${index}][1]`}
|
name={`custom_config_list[${index}][1]`}
|
||||||
component="div"
|
component="div"
|
||||||
className="text-error text-sm mt-1"
|
className="text-error text-sm mt-1"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="my-auto">
|
||||||
|
<FiX
|
||||||
|
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
|
||||||
|
onClick={() => arrayHelpers.remove(index)}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="my-auto">
|
|
||||||
<FiX
|
|
||||||
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
|
|
||||||
onClick={() => arrayHelpers.remove(index)}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
);
|
||||||
);
|
})}
|
||||||
})}
|
|
||||||
|
|
||||||
<Button
|
<Button
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
arrayHelpers.push(["", ""]);
|
arrayHelpers.push(["", ""]);
|
||||||
}}
|
}}
|
||||||
className="mt-3"
|
className="mt-3"
|
||||||
color="green"
|
color="green"
|
||||||
size="xs"
|
size="xs"
|
||||||
type="button"
|
type="button"
|
||||||
icon={FiPlus}
|
icon={FiPlus}
|
||||||
>
|
>
|
||||||
Add New
|
Add New
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<Divider />
|
<Divider />
|
||||||
|
|
||||||
<TextArrayField
|
<TextArrayField
|
||||||
name="model_names"
|
name="model_names"
|
||||||
label="Model Names"
|
label="Model Names"
|
||||||
values={values}
|
values={values}
|
||||||
subtext={`List the individual models that you want to make
|
subtext={`List the individual models that you want to make
|
||||||
available as a part of this provider. At least one must be specified.
|
available as a part of this provider. At least one must be specified.
|
||||||
As an example, for OpenAI one model might be "gpt-4".`}
|
As an example, for OpenAI one model might be "gpt-4".`}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<Divider />
|
<Divider />
|
||||||
|
|
||||||
<TextFormField
|
<TextFormField
|
||||||
name="default_model_name"
|
name="default_model_name"
|
||||||
subtext={`
|
subtext={`
|
||||||
The model to use by default for this provider unless
|
The model to use by default for this provider unless
|
||||||
otherwise specified. Must be one of the models listed
|
otherwise specified. Must be one of the models listed
|
||||||
above.`}
|
above.`}
|
||||||
label="Default Model"
|
label="Default Model"
|
||||||
placeholder="E.g. gpt-4"
|
placeholder="E.g. gpt-4"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<TextFormField
|
<TextFormField
|
||||||
name="fast_default_model_name"
|
name="fast_default_model_name"
|
||||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||||
for this provider. If not set, will use
|
for this provider. If not set, will use
|
||||||
the Default Model configured above.`}
|
the Default Model configured above.`}
|
||||||
label="[Optional] Fast Model"
|
label="[Optional] Fast Model"
|
||||||
placeholder="E.g. gpt-4"
|
placeholder="E.g. gpt-4"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<Divider />
|
<Divider />
|
||||||
|
|
||||||
<div>
|
<AdvancedOptionsToggle
|
||||||
{/* NOTE: this is above the test button to make sure it's visible */}
|
showAdvancedOptions={showAdvancedOptions}
|
||||||
{testError && <Text className="text-error mt-2">{testError}</Text>}
|
setShowAdvancedOptions={setShowAdvancedOptions}
|
||||||
|
/>
|
||||||
|
|
||||||
<div className="flex w-full mt-4">
|
{showAdvancedOptions && (
|
||||||
<Button type="submit" size="xs">
|
<>
|
||||||
{isTesting ? (
|
{isPaidEnterpriseFeaturesEnabled && userGroups && (
|
||||||
<LoadingAnimation text="Testing" />
|
<>
|
||||||
) : existingLlmProvider ? (
|
<BooleanFormField
|
||||||
"Update"
|
small
|
||||||
) : (
|
noPadding
|
||||||
"Enable"
|
alignTop
|
||||||
|
name="is_public"
|
||||||
|
label="Is Public?"
|
||||||
|
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
|
||||||
|
/>
|
||||||
|
|
||||||
|
{userGroups &&
|
||||||
|
userGroups.length > 0 &&
|
||||||
|
!values.is_public && (
|
||||||
|
<div>
|
||||||
|
<Text>
|
||||||
|
Select which User Groups should have access to this
|
||||||
|
LLM Provider.
|
||||||
|
</Text>
|
||||||
|
<div className="flex flex-wrap gap-2 mt-2">
|
||||||
|
{userGroups.map((userGroup) => {
|
||||||
|
const isSelected = values.groups.includes(
|
||||||
|
userGroup.id
|
||||||
|
);
|
||||||
|
return (
|
||||||
|
<Bubble
|
||||||
|
key={userGroup.id}
|
||||||
|
isSelected={isSelected}
|
||||||
|
onClick={() => {
|
||||||
|
if (isSelected) {
|
||||||
|
setFieldValue(
|
||||||
|
"groups",
|
||||||
|
values.groups.filter(
|
||||||
|
(id) => id !== userGroup.id
|
||||||
|
)
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
setFieldValue("groups", [
|
||||||
|
...values.groups,
|
||||||
|
userGroup.id,
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div className="flex">
|
||||||
|
<GroupsIcon />
|
||||||
|
<div className="ml-1">{userGroup.name}</div>
|
||||||
|
</div>
|
||||||
|
</Bubble>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
)}
|
)}
|
||||||
</Button>
|
</>
|
||||||
{existingLlmProvider && (
|
)}
|
||||||
<Button
|
|
||||||
type="button"
|
|
||||||
color="red"
|
|
||||||
className="ml-3"
|
|
||||||
size="xs"
|
|
||||||
icon={FiTrash}
|
|
||||||
onClick={async () => {
|
|
||||||
const response = await fetch(
|
|
||||||
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
|
|
||||||
{
|
|
||||||
method: "DELETE",
|
|
||||||
}
|
|
||||||
);
|
|
||||||
if (!response.ok) {
|
|
||||||
const errorMsg = (await response.json()).detail;
|
|
||||||
alert(`Failed to delete provider: ${errorMsg}`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
<div>
|
||||||
onClose();
|
{/* NOTE: this is above the test button to make sure it's visible */}
|
||||||
}}
|
{testError && (
|
||||||
>
|
<Text className="text-error mt-2">{testError}</Text>
|
||||||
Delete
|
|
||||||
</Button>
|
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
<div className="flex w-full mt-4">
|
||||||
|
<Button type="submit" size="xs">
|
||||||
|
{isTesting ? (
|
||||||
|
<LoadingAnimation text="Testing" />
|
||||||
|
) : existingLlmProvider ? (
|
||||||
|
"Update"
|
||||||
|
) : (
|
||||||
|
"Enable"
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
{existingLlmProvider && (
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
color="red"
|
||||||
|
className="ml-3"
|
||||||
|
size="xs"
|
||||||
|
icon={FiTrash}
|
||||||
|
onClick={async () => {
|
||||||
|
const response = await fetch(
|
||||||
|
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
|
||||||
|
{
|
||||||
|
method: "DELETE",
|
||||||
|
}
|
||||||
|
);
|
||||||
|
if (!response.ok) {
|
||||||
|
const errorMsg = (await response.json()).detail;
|
||||||
|
alert(`Failed to delete provider: ${errorMsg}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||||
|
onClose();
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Delete
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</Form>
|
||||||
</Form>
|
);
|
||||||
)}
|
}}
|
||||||
</Formik>
|
</Formik>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
import { LoadingAnimation } from "@/components/Loading";
|
import { LoadingAnimation } from "@/components/Loading";
|
||||||
|
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
||||||
import { Button, Divider, Text } from "@tremor/react";
|
import { Button, Divider, Text } from "@tremor/react";
|
||||||
import { Form, Formik } from "formik";
|
import { Form, Formik } from "formik";
|
||||||
import { FiTrash } from "react-icons/fi";
|
import { FiTrash } from "react-icons/fi";
|
||||||
@@ -6,11 +7,16 @@ import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
|
|||||||
import {
|
import {
|
||||||
SelectorFormField,
|
SelectorFormField,
|
||||||
TextFormField,
|
TextFormField,
|
||||||
|
BooleanFormField,
|
||||||
} from "@/components/admin/connectors/Field";
|
} from "@/components/admin/connectors/Field";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
|
import { Bubble } from "@/components/Bubble";
|
||||||
|
import { GroupsIcon } from "@/components/icons/icons";
|
||||||
import { useSWRConfig } from "swr";
|
import { useSWRConfig } from "swr";
|
||||||
|
import { useUserGroups } from "@/lib/hooks";
|
||||||
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
|
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
|
||||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||||
|
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||||
import * as Yup from "yup";
|
import * as Yup from "yup";
|
||||||
import isEqual from "lodash/isEqual";
|
import isEqual from "lodash/isEqual";
|
||||||
|
|
||||||
@@ -29,9 +35,16 @@ export function LLMProviderUpdateForm({
|
|||||||
}) {
|
}) {
|
||||||
const { mutate } = useSWRConfig();
|
const { mutate } = useSWRConfig();
|
||||||
|
|
||||||
|
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||||
|
|
||||||
|
// EE only
|
||||||
|
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
|
||||||
|
|
||||||
const [isTesting, setIsTesting] = useState(false);
|
const [isTesting, setIsTesting] = useState(false);
|
||||||
const [testError, setTestError] = useState<string>("");
|
const [testError, setTestError] = useState<string>("");
|
||||||
|
|
||||||
|
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||||
|
|
||||||
// Define the initial values based on the provider's requirements
|
// Define the initial values based on the provider's requirements
|
||||||
const initialValues = {
|
const initialValues = {
|
||||||
name: existingLlmProvider?.name ?? "",
|
name: existingLlmProvider?.name ?? "",
|
||||||
@@ -54,6 +67,8 @@ export function LLMProviderUpdateForm({
|
|||||||
},
|
},
|
||||||
{} as { [key: string]: string }
|
{} as { [key: string]: string }
|
||||||
),
|
),
|
||||||
|
is_public: existingLlmProvider?.is_public ?? true,
|
||||||
|
groups: existingLlmProvider?.groups ?? [],
|
||||||
};
|
};
|
||||||
|
|
||||||
const [validatedConfig, setValidatedConfig] = useState(
|
const [validatedConfig, setValidatedConfig] = useState(
|
||||||
@@ -91,6 +106,9 @@ export function LLMProviderUpdateForm({
|
|||||||
: {}),
|
: {}),
|
||||||
default_model_name: Yup.string().required("Model name is required"),
|
default_model_name: Yup.string().required("Model name is required"),
|
||||||
fast_default_model_name: Yup.string().nullable(),
|
fast_default_model_name: Yup.string().nullable(),
|
||||||
|
// EE Only
|
||||||
|
is_public: Yup.boolean().required(),
|
||||||
|
groups: Yup.array().of(Yup.number()),
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -193,7 +211,7 @@ export function LLMProviderUpdateForm({
|
|||||||
setSubmitting(false);
|
setSubmitting(false);
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{({ values }) => (
|
{({ values, setFieldValue }) => (
|
||||||
<Form>
|
<Form>
|
||||||
<TextFormField
|
<TextFormField
|
||||||
name="name"
|
name="name"
|
||||||
@@ -293,6 +311,69 @@ export function LLMProviderUpdateForm({
|
|||||||
|
|
||||||
<Divider />
|
<Divider />
|
||||||
|
|
||||||
|
<AdvancedOptionsToggle
|
||||||
|
showAdvancedOptions={showAdvancedOptions}
|
||||||
|
setShowAdvancedOptions={setShowAdvancedOptions}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{showAdvancedOptions && (
|
||||||
|
<>
|
||||||
|
{isPaidEnterpriseFeaturesEnabled && userGroups && (
|
||||||
|
<>
|
||||||
|
<BooleanFormField
|
||||||
|
small
|
||||||
|
noPadding
|
||||||
|
alignTop
|
||||||
|
name="is_public"
|
||||||
|
label="Is Public?"
|
||||||
|
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
|
||||||
|
/>
|
||||||
|
|
||||||
|
{userGroups && userGroups.length > 0 && !values.is_public && (
|
||||||
|
<div>
|
||||||
|
<Text>
|
||||||
|
Select which User Groups should have access to this LLM
|
||||||
|
Provider.
|
||||||
|
</Text>
|
||||||
|
<div className="flex flex-wrap gap-2 mt-2">
|
||||||
|
{userGroups.map((userGroup) => {
|
||||||
|
const isSelected = values.groups.includes(
|
||||||
|
userGroup.id
|
||||||
|
);
|
||||||
|
return (
|
||||||
|
<Bubble
|
||||||
|
key={userGroup.id}
|
||||||
|
isSelected={isSelected}
|
||||||
|
onClick={() => {
|
||||||
|
if (isSelected) {
|
||||||
|
setFieldValue(
|
||||||
|
"groups",
|
||||||
|
values.groups.filter(
|
||||||
|
(id) => id !== userGroup.id
|
||||||
|
)
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
setFieldValue("groups", [
|
||||||
|
...values.groups,
|
||||||
|
userGroup.id,
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div className="flex">
|
||||||
|
<GroupsIcon />
|
||||||
|
<div className="ml-1">{userGroup.name}</div>
|
||||||
|
</div>
|
||||||
|
</Bubble>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
<div>
|
<div>
|
||||||
{/* NOTE: this is above the test button to make sure it's visible */}
|
{/* NOTE: this is above the test button to make sure it's visible */}
|
||||||
{testError && <Text className="text-error mt-2">{testError}</Text>}
|
{testError && <Text className="text-error mt-2">{testError}</Text>}
|
||||||
|
@@ -17,6 +17,8 @@ export interface WellKnownLLMProviderDescriptor {
|
|||||||
llm_names: string[];
|
llm_names: string[];
|
||||||
default_model: string | null;
|
default_model: string | null;
|
||||||
default_fast_model: string | null;
|
default_fast_model: string | null;
|
||||||
|
is_public: boolean;
|
||||||
|
groups: number[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface LLMProvider {
|
export interface LLMProvider {
|
||||||
@@ -28,6 +30,8 @@ export interface LLMProvider {
|
|||||||
custom_config: { [key: string]: string } | null;
|
custom_config: { [key: string]: string } | null;
|
||||||
default_model_name: string;
|
default_model_name: string;
|
||||||
fast_default_model_name: string | null;
|
fast_default_model_name: string | null;
|
||||||
|
is_public: boolean;
|
||||||
|
groups: number[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface FullLLMProvider extends LLMProvider {
|
export interface FullLLMProvider extends LLMProvider {
|
||||||
@@ -44,4 +48,6 @@ export interface LLMProviderDescriptor {
|
|||||||
default_model_name: string;
|
default_model_name: string;
|
||||||
fast_default_model_name: string | null;
|
fast_default_model_name: string | null;
|
||||||
is_default_provider: boolean | null;
|
is_default_provider: boolean | null;
|
||||||
|
is_public: boolean;
|
||||||
|
groups: number[];
|
||||||
}
|
}
|
||||||
|
26
web/src/components/AdvancedOptionsToggle.tsx
Normal file
26
web/src/components/AdvancedOptionsToggle.tsx
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import React from "react";
|
||||||
|
import { Button } from "@tremor/react";
|
||||||
|
import { FiChevronDown, FiChevronRight } from "react-icons/fi";
|
||||||
|
|
||||||
|
interface AdvancedOptionsToggleProps {
|
||||||
|
showAdvancedOptions: boolean;
|
||||||
|
setShowAdvancedOptions: (show: boolean) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function AdvancedOptionsToggle({
|
||||||
|
showAdvancedOptions,
|
||||||
|
setShowAdvancedOptions,
|
||||||
|
}: AdvancedOptionsToggleProps) {
|
||||||
|
return (
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="light"
|
||||||
|
size="xs"
|
||||||
|
icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
|
||||||
|
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
|
||||||
|
className="mb-4 text-xs text-text-500 hover:text-text-400"
|
||||||
|
>
|
||||||
|
Advanced Options
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
}
|
Reference in New Issue
Block a user