sanitize llm keys and handle updates properly (#4270)

* sanitize llm keys and handle updates properly

* fix llm provider testing

* fix test

* mypy

* fix default model editing

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
This commit is contained in:
rkuo-danswer 2025-03-19 18:13:02 -07:00 committed by GitHub
parent 5dda53eec3
commit 85ebadc8eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 146 additions and 62 deletions

View File

@ -271,6 +271,7 @@ def configure_default_api_keys(db_session: Session) -> None:
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
api_key_changed=True,
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
@ -283,7 +284,7 @@ def configure_default_api_keys(db_session: Session) -> None:
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
openai_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
@ -291,9 +292,10 @@ def configure_default_api_keys(db_session: Session) -> None:
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
api_key_changed=True,
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
full_provider = upsert_llm_provider(openai_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")

View File

@ -16,8 +16,8 @@ from onyx.db.models import User__UserGroup
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from shared_configs.enums import EmbeddingProvider
@ -67,7 +67,7 @@ def upsert_cloud_embedding_provider(
def upsert_llm_provider(
llm_provider: LLMProviderUpsertRequest,
db_session: Session,
) -> FullLLMProvider:
) -> LLMProviderView:
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
@ -98,7 +98,7 @@ def upsert_llm_provider(
group_ids=llm_provider.groups,
db_session=db_session,
)
full_llm_provider = FullLLMProvider.from_model(existing_llm_provider)
full_llm_provider = LLMProviderView.from_model(existing_llm_provider)
db_session.commit()
@ -132,6 +132,16 @@ def fetch_existing_llm_providers(
return list(db_session.scalars(stmt).all())
def fetch_existing_llm_provider(
provider_name: str, db_session: Session
) -> LLMProviderModel | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
)
return provider_model
def fetch_existing_llm_providers_for_user(
db_session: Session,
user: User | None = None,
@ -177,7 +187,7 @@ def fetch_embedding_provider(
)
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_provider == True # noqa: E712
@ -185,10 +195,10 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
return LLMProviderView.from_model(provider_model)
def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None:
def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_vision_provider == True # noqa: E712
@ -196,16 +206,18 @@ def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
return LLMProviderView.from_model(provider_model)
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
def fetch_llm_provider_view(
db_session: Session, provider_name: str
) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
return LLMProviderView.from_model(provider_model)
def remove_embedding_provider(

View File

@ -9,14 +9,14 @@ from onyx.db.engine import get_session_with_current_tenant
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_vision_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_provider
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.models import Persona
from onyx.llm.chat_llm import DefaultMultiLLM
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderView
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
@ -62,7 +62,7 @@ def get_llms_for_persona(
)
with get_session_context_manager() as db_session:
llm_provider = fetch_provider(db_session, provider_name)
llm_provider = fetch_llm_provider_view(db_session, provider_name)
if not llm_provider:
raise ValueError("No LLM provider found")
@ -106,7 +106,7 @@ def get_default_llm_with_vision(
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()
def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM:
def create_vision_llm(provider: LLMProviderView, model: str) -> LLM:
"""Helper to create an LLM if the provider supports image input."""
return get_llm(
provider=provider.provider,
@ -148,7 +148,7 @@ def get_default_llm_with_vision(
provider.default_vision_model, provider.provider
):
return create_vision_llm(
FullLLMProvider.from_model(provider), provider.default_vision_model
LLMProviderView.from_model(provider), provider.default_vision_model
)
return None

View File

@ -9,9 +9,9 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user
from onyx.db.llm import fetch_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_vision_provider
@ -24,9 +24,9 @@ from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.utils.logger import setup_logger
@ -49,11 +49,27 @@ def fetch_llm_options(
def test_llm_configuration(
test_llm_request: TestLLMRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
"""Test regular llm and fast llm settings"""
# the api key is sanitized if we are testing a provider already in the system
test_api_key = test_llm_request.api_key
if test_llm_request.name:
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
# as it turns out the name is not editable in the UI and other code also keys off name,
# so we won't rock the boat just yet.
existing_provider = fetch_existing_llm_provider(
test_llm_request.name, db_session
)
if existing_provider:
test_api_key = existing_provider.api_key
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
api_key=test_llm_request.api_key,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config,
@ -69,7 +85,7 @@ def test_llm_configuration(
fast_llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.fast_default_model_name,
api_key=test_llm_request.api_key,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config,
@ -119,11 +135,17 @@ def test_default_provider(
def list_llm_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[FullLLMProvider]:
return [
FullLLMProvider.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
]
) -> list[LLMProviderView]:
llm_provider_list: list[LLMProviderView] = []
for llm_provider_model in fetch_existing_llm_providers(db_session):
full_llm_provider = LLMProviderView.from_model(llm_provider_model)
if full_llm_provider.api_key:
full_llm_provider.api_key = (
full_llm_provider.api_key[:4] + "****" + full_llm_provider.api_key[-4:]
)
llm_provider_list.append(full_llm_provider)
return llm_provider_list
@admin_router.put("/provider")
@ -135,11 +157,11 @@ def put_llm_provider(
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProvider:
) -> LLMProviderView:
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = fetch_provider(db_session, llm_provider.name)
existing_provider = fetch_existing_llm_provider(llm_provider.name, db_session)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
@ -161,6 +183,11 @@ def put_llm_provider(
llm_provider.fast_default_model_name
)
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider.api_key_changed:
llm_provider.api_key = existing_provider.api_key
try:
return upsert_llm_provider(
llm_provider=llm_provider,
@ -234,7 +261,7 @@ def get_vision_capable_providers(
# Only include providers with at least one vision-capable model
if vision_models:
provider_dict = FullLLMProvider.from_model(provider).model_dump()
provider_dict = LLMProviderView.from_model(provider).model_dump()
provider_dict["vision_models"] = vision_models
logger.info(
f"Vision provider: {provider.provider} with models: {vision_models}"

View File

@ -12,6 +12,7 @@ if TYPE_CHECKING:
class TestLLMRequest(BaseModel):
# provider level
name: str | None = None
provider: str
api_key: str | None = None
api_base: str | None = None
@ -76,16 +77,19 @@ class LLMProviderUpsertRequest(LLMProvider):
# should only be used for a "custom" provider
# for default providers, the built-in model names are used
model_names: list[str] | None = None
api_key_changed: bool = False
class FullLLMProvider(LLMProvider):
class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
id: int
is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None
model_names: list[str]
@classmethod
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider":
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
@ -111,7 +115,7 @@ class FullLLMProvider(LLMProvider):
)
class VisionProviderResponse(FullLLMProvider):
class VisionProviderResponse(LLMProviderView):
"""Response model for vision providers endpoint, including vision-specific fields."""
vision_models: list[str]

View File

@ -307,6 +307,7 @@ def setup_postgres(db_session: Session) -> None:
groups=[],
display_model_names=OPEN_AI_MODEL_NAMES,
model_names=OPEN_AI_MODEL_NAMES,
api_key_changed=True,
)
new_llm_provider = upsert_llm_provider(
llm_provider=model_req, db_session=db_session

View File

@ -3,8 +3,8 @@ from uuid import uuid4
import requests
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
@ -39,6 +39,7 @@ class LLMProviderManager:
groups=groups or [],
display_model_names=None,
model_names=None,
api_key_changed=True,
)
llm_response = requests.put(
@ -90,7 +91,7 @@ class LLMProviderManager:
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
) -> list[FullLLMProvider]:
) -> list[LLMProviderView]:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=user_performing_action.headers
@ -98,7 +99,7 @@ class LLMProviderManager:
else GENERAL_HEADERS,
)
response.raise_for_status()
return [FullLLMProvider(**ug) for ug in response.json()]
return [LLMProviderView(**ug) for ug in response.json()]
@staticmethod
def verify(
@ -111,18 +112,19 @@ class LLMProviderManager:
if llm_provider.id == fetched_llm_provider.id:
if verify_deleted:
raise ValueError(
f"User group {llm_provider.id} found but should be deleted"
f"LLM Provider {llm_provider.id} found but should be deleted"
)
fetched_llm_groups = set(fetched_llm_provider.groups)
llm_provider_groups = set(llm_provider.groups)
# NOTE: returned api keys are sanitized and should not match
if (
fetched_llm_groups == llm_provider_groups
and llm_provider.provider == fetched_llm_provider.provider
and llm_provider.api_key == fetched_llm_provider.api_key
and llm_provider.default_model_name
== fetched_llm_provider.default_model_name
and llm_provider.is_public == fetched_llm_provider.is_public
):
return
if not verify_deleted:
raise ValueError(f"User group {llm_provider.id} not found")
raise ValueError(f"LLM Provider {llm_provider.id} not found")

View File

@ -34,6 +34,7 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
json={
"name": str(uuid.uuid4()),
"provider": "openai",
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
@ -49,6 +50,9 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
assert provider_data["model_names"] == _DEFAULT_MODELS
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
assert provider_data["display_model_names"] is None
assert (
provider_data["api_key"] == "sk-0****0000"
) # test that returned key is sanitized
def test_update_llm_provider_model_names(reset: None) -> None:
@ -64,10 +68,12 @@ def test_update_llm_provider_model_names(reset: None) -> None:
json={
"name": name,
"provider": "openai",
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": [_DEFAULT_MODELS[0]],
"is_public": True,
"groups": [],
"api_key_changed": True,
},
)
assert response.status_code == 200
@ -81,6 +87,7 @@ def test_update_llm_provider_model_names(reset: None) -> None:
"id": created_provider["id"],
"name": name,
"provider": created_provider["provider"],
"api_key": "sk-000000000000000000000000000000000000000000000001",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
@ -93,6 +100,30 @@ def test_update_llm_provider_model_names(reset: None) -> None:
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is not None
assert provider_data["model_names"] == _DEFAULT_MODELS
assert (
provider_data["api_key"] == "sk-0****0000"
) # test that key was NOT updated due to api_key_changed not being set
# Update with api_key_changed properly set
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"id": created_provider["id"],
"name": name,
"provider": created_provider["provider"],
"api_key": "sk-000000000000000000000000000000000000000000000001",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
"groups": [],
"api_key_changed": True,
},
)
assert response.status_code == 200
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is not None
assert provider_data["api_key"] == "sk-0****0001" # test that key was updated
def test_delete_llm_provider(reset: None) -> None:
@ -107,6 +138,7 @@ def test_delete_llm_provider(reset: None) -> None:
json={
"name": "test-provider-delete",
"provider": "openai",
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,

View File

@ -61,7 +61,7 @@ import {
import { buildImgUrl } from "@/app/chat/files/images/utils";
import { useAssistants } from "@/components/context/AssistantsContext";
import { debounce } from "lodash";
import { FullLLMProvider } from "../configuration/llm/interfaces";
import { LLMProviderView } from "../configuration/llm/interfaces";
import StarterMessagesList from "./StarterMessageList";
import { Switch, SwitchField } from "@/components/ui/switch";
@ -123,7 +123,7 @@ export function AssistantEditor({
documentSets: DocumentSet[];
user: User | null;
defaultPublic: boolean;
llmProviders: FullLLMProvider[];
llmProviders: LLMProviderView[];
tools: ToolSnapshot[];
shouldAddAssistantToUserPreferences?: boolean;
admin?: boolean;

View File

@ -1,4 +1,4 @@
import { FullLLMProvider } from "../configuration/llm/interfaces";
import { LLMProviderView } from "../configuration/llm/interfaces";
import { Persona, StarterMessage } from "./interfaces";
interface PersonaUpsertRequest {
@ -319,7 +319,7 @@ export function checkPersonaRequiresImageGeneration(persona: Persona) {
}
export function providersContainImageGeneratingSupport(
providers: FullLLMProvider[]
providers: LLMProviderView[]
) {
return providers.some((provider) => provider.provider === "openai");
}

View File

@ -1,5 +1,5 @@
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
import { Modal } from "@/components/Modal";
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm";
@ -19,7 +19,7 @@ function LLMProviderUpdateModal({
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined;
onClose: () => void;
existingLlmProvider?: FullLLMProvider;
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
setPopup?: (popup: PopupSpec) => void;
}) {
@ -61,7 +61,7 @@ function LLMProviderDisplay({
shouldMarkAsDefault,
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined;
existingLlmProvider: FullLLMProvider;
existingLlmProvider: LLMProviderView;
shouldMarkAsDefault?: boolean;
}) {
const [formIsVisible, setFormIsVisible] = useState(false);
@ -146,7 +146,7 @@ export function ConfiguredLLMProviderDisplay({
existingLlmProviders,
llmProviderDescriptors,
}: {
existingLlmProviders: FullLLMProvider[];
existingLlmProviders: LLMProviderView[];
llmProviderDescriptors: WellKnownLLMProviderDescriptor[];
}) {
existingLlmProviders = existingLlmProviders.sort((a, b) => {

View File

@ -21,7 +21,7 @@ import {
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { FullLLMProvider } from "./interfaces";
import { LLMProviderView } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
@ -43,7 +43,7 @@ export function CustomLLMProviderUpdateForm({
hideSuccess,
}: {
onClose: () => void;
existingLlmProvider?: FullLLMProvider;
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
setPopup?: (popup: PopupSpec) => void;
hideSuccess?: boolean;
@ -165,7 +165,7 @@ export function CustomLLMProviderUpdateForm({
}
if (shouldMarkAsDefault) {
const newLlmProvider = (await response.json()) as FullLLMProvider;
const newLlmProvider = (await response.json()) as LLMProviderView;
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
{

View File

@ -9,7 +9,7 @@ import Text from "@/components/ui/text";
import Title from "@/components/ui/title";
import { Button } from "@/components/ui/button";
import { ThreeDotsLoader } from "@/components/Loading";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
@ -25,7 +25,7 @@ function LLMProviderUpdateModal({
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null;
onClose: () => void;
existingLlmProvider?: FullLLMProvider;
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
setPopup?: (popup: PopupSpec) => void;
}) {
@ -99,7 +99,7 @@ function DefaultLLMProviderDisplay({
function AddCustomLLMProvider({
existingLlmProviders,
}: {
existingLlmProviders: FullLLMProvider[];
existingLlmProviders: LLMProviderView[];
}) {
const [formIsVisible, setFormIsVisible] = useState(false);
@ -130,7 +130,7 @@ export function LLMConfiguration() {
const { data: llmProviderDescriptors } = useSWR<
WellKnownLLMProviderDescriptor[]
>("/api/admin/llm/built-in/options", errorHandlingFetcher);
const { data: existingLlmProviders } = useSWR<FullLLMProvider[]>(
const { data: existingLlmProviders } = useSWR<LLMProviderView[]>(
LLM_PROVIDERS_ADMIN_URL,
errorHandlingFetcher
);

View File

@ -14,7 +14,7 @@ import {
import { useState } from "react";
import { useSWRConfig } from "swr";
import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
@ -31,7 +31,7 @@ export function LLMProviderUpdateForm({
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor;
onClose: () => void;
existingLlmProvider?: FullLLMProvider;
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
hideAdvanced?: boolean;
setPopup?: (popup: PopupSpec) => void;
@ -73,6 +73,7 @@ export function LLMProviderUpdateForm({
defaultModelsByProvider[llmProviderDescriptor.name] ||
[],
deployment_name: existingLlmProvider?.deployment_name,
api_key_changed: false,
};
// Setup validation schema if required
@ -113,6 +114,7 @@ export function LLMProviderUpdateForm({
is_public: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
display_model_names: Yup.array().of(Yup.string()),
api_key_changed: Yup.boolean(),
});
return (
@ -122,6 +124,8 @@ export function LLMProviderUpdateForm({
onSubmit={async (values, { setSubmitting }) => {
setSubmitting(true);
values.api_key_changed = values.api_key !== initialValues.api_key;
// test the configuration
if (!isEqual(values, initialValues)) {
setIsTesting(true);
@ -180,7 +184,7 @@ export function LLMProviderUpdateForm({
}
if (shouldMarkAsDefault) {
const newLlmProvider = (await response.json()) as FullLLMProvider;
const newLlmProvider = (await response.json()) as LLMProviderView;
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
{

View File

@ -53,14 +53,14 @@ export interface LLMProvider {
is_default_vision_provider: boolean | null;
}
export interface FullLLMProvider extends LLMProvider {
export interface LLMProviderView extends LLMProvider {
id: number;
is_default_provider: boolean | null;
model_names: string[];
icon?: React.FC<{ size?: number; className?: string }>;
}
export interface VisionProvider extends FullLLMProvider {
export interface VisionProvider extends LLMProviderView {
vision_models: string[];
}

View File

@ -1,5 +1,5 @@
import {
FullLLMProvider,
LLMProviderView,
WellKnownLLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { User } from "@/lib/types";
@ -36,7 +36,7 @@ export async function checkLlmProvider(user: User | null) {
const [providerResponse, optionsResponse, defaultCheckResponse] =
await Promise.all(tasks);
let providers: FullLLMProvider[] = [];
let providers: LLMProviderView[] = [];
if (providerResponse?.ok) {
providers = await providerResponse.json();
}

View File

@ -3,7 +3,7 @@ import { CCPairBasicInfo, DocumentSet, User } from "../types";
import { getCurrentUserSS } from "../userSS";
import { fetchSS } from "../utilsSS";
import {
FullLLMProvider,
LLMProviderView,
getProviderIcon,
} from "@/app/admin/configuration/llm/interfaces";
import { ToolSnapshot } from "../tools/interfaces";
@ -16,7 +16,7 @@ export async function fetchAssistantEditorInfoSS(
{
ccPairs: CCPairBasicInfo[];
documentSets: DocumentSet[];
llmProviders: FullLLMProvider[];
llmProviders: LLMProviderView[];
user: User | null;
existingPersona: Persona | null;
tools: ToolSnapshot[];
@ -83,7 +83,7 @@ export async function fetchAssistantEditorInfoSS(
];
}
const llmProviders = (await llmProvidersResponse.json()) as FullLLMProvider[];
const llmProviders = (await llmProvidersResponse.json()) as LLMProviderView[];
if (personaId && personaResponse && !personaResponse.ok) {
return [null, `Failed to fetch Persona - ${await personaResponse.text()}`];