diff --git a/backend/ee/onyx/server/tenants/provisioning.py b/backend/ee/onyx/server/tenants/provisioning.py index 19880339d..9a1b425b7 100644 --- a/backend/ee/onyx/server/tenants/provisioning.py +++ b/backend/ee/onyx/server/tenants/provisioning.py @@ -271,6 +271,7 @@ def configure_default_api_keys(db_session: Session) -> None: fast_default_model_name="claude-3-5-sonnet-20241022", model_names=ANTHROPIC_MODEL_NAMES, display_model_names=["claude-3-5-sonnet-20241022"], + api_key_changed=True, ) try: full_provider = upsert_llm_provider(anthropic_provider, db_session) @@ -283,7 +284,7 @@ def configure_default_api_keys(db_session: Session) -> None: ) if OPENAI_DEFAULT_API_KEY: - open_provider = LLMProviderUpsertRequest( + openai_provider = LLMProviderUpsertRequest( name="OpenAI", provider=OPENAI_PROVIDER_NAME, api_key=OPENAI_DEFAULT_API_KEY, @@ -291,9 +292,10 @@ def configure_default_api_keys(db_session: Session) -> None: fast_default_model_name="gpt-4o-mini", model_names=OPEN_AI_MODEL_NAMES, display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"], + api_key_changed=True, ) try: - full_provider = upsert_llm_provider(open_provider, db_session) + full_provider = upsert_llm_provider(openai_provider, db_session) update_default_provider(full_provider.id, db_session) except Exception as e: logger.error(f"Failed to configure OpenAI provider: {e}") diff --git a/backend/onyx/db/llm.py b/backend/onyx/db/llm.py index e5b1602b7..7a70462ad 100644 --- a/backend/onyx/db/llm.py +++ b/backend/onyx/db/llm.py @@ -16,8 +16,8 @@ from onyx.db.models import User__UserGroup from onyx.llm.utils import model_supports_image_input from onyx.server.manage.embedding.models import CloudEmbeddingProvider from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest -from onyx.server.manage.llm.models import FullLLMProvider from onyx.server.manage.llm.models import LLMProviderUpsertRequest +from onyx.server.manage.llm.models import LLMProviderView from shared_configs.enums import EmbeddingProvider @@ -67,7 +67,7 @@ def upsert_cloud_embedding_provider( def upsert_llm_provider( llm_provider: LLMProviderUpsertRequest, db_session: Session, -) -> FullLLMProvider: +) -> LLMProviderView: existing_llm_provider = db_session.scalar( select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name) ) @@ -98,7 +98,7 @@ def upsert_llm_provider( group_ids=llm_provider.groups, db_session=db_session, ) - full_llm_provider = FullLLMProvider.from_model(existing_llm_provider) + full_llm_provider = LLMProviderView.from_model(existing_llm_provider) db_session.commit() @@ -132,6 +132,16 @@ def fetch_existing_llm_providers( return list(db_session.scalars(stmt).all()) +def fetch_existing_llm_provider( + provider_name: str, db_session: Session +) -> LLMProviderModel | None: + provider_model = db_session.scalar( + select(LLMProviderModel).where(LLMProviderModel.name == provider_name) + ) + + return provider_model + + def fetch_existing_llm_providers_for_user( db_session: Session, user: User | None = None, @@ -177,7 +187,7 @@ def fetch_embedding_provider( ) -def fetch_default_provider(db_session: Session) -> FullLLMProvider | None: +def fetch_default_provider(db_session: Session) -> LLMProviderView | None: provider_model = db_session.scalar( select(LLMProviderModel).where( LLMProviderModel.is_default_provider == True # noqa: E712 @@ -185,10 +195,10 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None: ) if not provider_model: return None - return FullLLMProvider.from_model(provider_model) + return LLMProviderView.from_model(provider_model) -def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None: +def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None: provider_model = db_session.scalar( select(LLMProviderModel).where( LLMProviderModel.is_default_vision_provider == True # noqa: E712 @@ -196,16 +206,18 @@ def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None ) if not provider_model: return None - return FullLLMProvider.from_model(provider_model) + return LLMProviderView.from_model(provider_model) -def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None: +def fetch_llm_provider_view( + db_session: Session, provider_name: str +) -> LLMProviderView | None: provider_model = db_session.scalar( select(LLMProviderModel).where(LLMProviderModel.name == provider_name) ) if not provider_model: return None - return FullLLMProvider.from_model(provider_model) + return LLMProviderView.from_model(provider_model) def remove_embedding_provider( diff --git a/backend/onyx/llm/factory.py b/backend/onyx/llm/factory.py index 3d0bb6b3b..c77518f51 100644 --- a/backend/onyx/llm/factory.py +++ b/backend/onyx/llm/factory.py @@ -9,14 +9,14 @@ from onyx.db.engine import get_session_with_current_tenant from onyx.db.llm import fetch_default_provider from onyx.db.llm import fetch_default_vision_provider from onyx.db.llm import fetch_existing_llm_providers -from onyx.db.llm import fetch_provider +from onyx.db.llm import fetch_llm_provider_view from onyx.db.models import Persona from onyx.llm.chat_llm import DefaultMultiLLM from onyx.llm.exceptions import GenAIDisabledException from onyx.llm.interfaces import LLM from onyx.llm.override_models import LLMOverride from onyx.llm.utils import model_supports_image_input -from onyx.server.manage.llm.models import FullLLMProvider +from onyx.server.manage.llm.models import LLMProviderView from onyx.utils.headers import build_llm_extra_headers from onyx.utils.logger import setup_logger from onyx.utils.long_term_log import LongTermLogger @@ -62,7 +62,7 @@ def get_llms_for_persona( ) with get_session_context_manager() as db_session: - llm_provider = fetch_provider(db_session, provider_name) + llm_provider = fetch_llm_provider_view(db_session, provider_name) if not llm_provider: raise ValueError("No LLM provider found") @@ -106,7 +106,7 @@ def get_default_llm_with_vision( if DISABLE_GENERATIVE_AI: raise GenAIDisabledException() - def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM: + def create_vision_llm(provider: LLMProviderView, model: str) -> LLM: """Helper to create an LLM if the provider supports image input.""" return get_llm( provider=provider.provider, @@ -148,7 +148,7 @@ def get_default_llm_with_vision( provider.default_vision_model, provider.provider ): return create_vision_llm( - FullLLMProvider.from_model(provider), provider.default_vision_model + LLMProviderView.from_model(provider), provider.default_vision_model ) return None diff --git a/backend/onyx/server/manage/llm/api.py b/backend/onyx/server/manage/llm/api.py index ceafca2e3..0a5ceb036 100644 --- a/backend/onyx/server/manage/llm/api.py +++ b/backend/onyx/server/manage/llm/api.py @@ -9,9 +9,9 @@ from sqlalchemy.orm import Session from onyx.auth.users import current_admin_user from onyx.auth.users import current_chat_accessible_user from onyx.db.engine import get_session +from onyx.db.llm import fetch_existing_llm_provider from onyx.db.llm import fetch_existing_llm_providers from onyx.db.llm import fetch_existing_llm_providers_for_user -from onyx.db.llm import fetch_provider from onyx.db.llm import remove_llm_provider from onyx.db.llm import update_default_provider from onyx.db.llm import update_default_vision_provider @@ -24,9 +24,9 @@ from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor from onyx.llm.utils import litellm_exception_to_error_msg from onyx.llm.utils import model_supports_image_input from onyx.llm.utils import test_llm -from onyx.server.manage.llm.models import FullLLMProvider from onyx.server.manage.llm.models import LLMProviderDescriptor from onyx.server.manage.llm.models import LLMProviderUpsertRequest +from onyx.server.manage.llm.models import LLMProviderView from onyx.server.manage.llm.models import TestLLMRequest from onyx.server.manage.llm.models import VisionProviderResponse from onyx.utils.logger import setup_logger @@ -49,11 +49,27 @@ def fetch_llm_options( def test_llm_configuration( test_llm_request: TestLLMRequest, _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), ) -> None: + """Test regular llm and fast llm settings""" + + # the api key is sanitized if we are testing a provider already in the system + + test_api_key = test_llm_request.api_key + if test_llm_request.name: + # NOTE: we are querying by name. we probably should be querying by an invariant id, but + # as it turns out the name is not editable in the UI and other code also keys off name, + # so we won't rock the boat just yet. + existing_provider = fetch_existing_llm_provider( + test_llm_request.name, db_session + ) + if existing_provider: + test_api_key = existing_provider.api_key + llm = get_llm( provider=test_llm_request.provider, model=test_llm_request.default_model_name, - api_key=test_llm_request.api_key, + api_key=test_api_key, api_base=test_llm_request.api_base, api_version=test_llm_request.api_version, custom_config=test_llm_request.custom_config, @@ -69,7 +85,7 @@ def test_llm_configuration( fast_llm = get_llm( provider=test_llm_request.provider, model=test_llm_request.fast_default_model_name, - api_key=test_llm_request.api_key, + api_key=test_api_key, api_base=test_llm_request.api_base, api_version=test_llm_request.api_version, custom_config=test_llm_request.custom_config, @@ -119,11 +135,17 @@ def test_default_provider( def list_llm_providers( _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), -) -> list[FullLLMProvider]: - return [ - FullLLMProvider.from_model(llm_provider_model) - for llm_provider_model in fetch_existing_llm_providers(db_session) - ] +) -> list[LLMProviderView]: + llm_provider_list: list[LLMProviderView] = [] + for llm_provider_model in fetch_existing_llm_providers(db_session): + full_llm_provider = LLMProviderView.from_model(llm_provider_model) + if full_llm_provider.api_key: + full_llm_provider.api_key = ( + full_llm_provider.api_key[:4] + "****" + full_llm_provider.api_key[-4:] + ) + llm_provider_list.append(full_llm_provider) + + return llm_provider_list @admin_router.put("/provider") @@ -135,11 +157,11 @@ def put_llm_provider( ), _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), -) -> FullLLMProvider: +) -> LLMProviderView: # validate request (e.g. if we're intending to create but the name already exists we should throw an error) # NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache # the result - existing_provider = fetch_provider(db_session, llm_provider.name) + existing_provider = fetch_existing_llm_provider(llm_provider.name, db_session) if existing_provider and is_creation: raise HTTPException( status_code=400, @@ -161,6 +183,11 @@ def put_llm_provider( llm_provider.fast_default_model_name ) + # the llm api key is sanitized when returned to clients, so the only time we + # should get a real key is when it is explicitly changed + if existing_provider and not llm_provider.api_key_changed: + llm_provider.api_key = existing_provider.api_key + try: return upsert_llm_provider( llm_provider=llm_provider, @@ -234,7 +261,7 @@ def get_vision_capable_providers( # Only include providers with at least one vision-capable model if vision_models: - provider_dict = FullLLMProvider.from_model(provider).model_dump() + provider_dict = LLMProviderView.from_model(provider).model_dump() provider_dict["vision_models"] = vision_models logger.info( f"Vision provider: {provider.provider} with models: {vision_models}" diff --git a/backend/onyx/server/manage/llm/models.py b/backend/onyx/server/manage/llm/models.py index 3172f5adf..9d5544d96 100644 --- a/backend/onyx/server/manage/llm/models.py +++ b/backend/onyx/server/manage/llm/models.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: class TestLLMRequest(BaseModel): # provider level + name: str | None = None provider: str api_key: str | None = None api_base: str | None = None @@ -76,16 +77,19 @@ class LLMProviderUpsertRequest(LLMProvider): # should only be used for a "custom" provider # for default providers, the built-in model names are used model_names: list[str] | None = None + api_key_changed: bool = False -class FullLLMProvider(LLMProvider): +class LLMProviderView(LLMProvider): + """Stripped down representation of LLMProvider for display / limited access info only""" + id: int is_default_provider: bool | None = None is_default_vision_provider: bool | None = None model_names: list[str] @classmethod - def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider": + def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView": return cls( id=llm_provider_model.id, name=llm_provider_model.name, @@ -111,7 +115,7 @@ class FullLLMProvider(LLMProvider): ) -class VisionProviderResponse(FullLLMProvider): +class VisionProviderResponse(LLMProviderView): """Response model for vision providers endpoint, including vision-specific fields.""" vision_models: list[str] diff --git a/backend/onyx/setup.py b/backend/onyx/setup.py index 1dff601ef..750b35d8d 100644 --- a/backend/onyx/setup.py +++ b/backend/onyx/setup.py @@ -307,6 +307,7 @@ def setup_postgres(db_session: Session) -> None: groups=[], display_model_names=OPEN_AI_MODEL_NAMES, model_names=OPEN_AI_MODEL_NAMES, + api_key_changed=True, ) new_llm_provider = upsert_llm_provider( llm_provider=model_req, db_session=db_session diff --git a/backend/tests/integration/common_utils/managers/llm_provider.py b/backend/tests/integration/common_utils/managers/llm_provider.py index 44d4ce501..33d29e42e 100644 --- a/backend/tests/integration/common_utils/managers/llm_provider.py +++ b/backend/tests/integration/common_utils/managers/llm_provider.py @@ -3,8 +3,8 @@ from uuid import uuid4 import requests -from onyx.server.manage.llm.models import FullLLMProvider from onyx.server.manage.llm.models import LLMProviderUpsertRequest +from onyx.server.manage.llm.models import LLMProviderView from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.test_models import DATestLLMProvider @@ -39,6 +39,7 @@ class LLMProviderManager: groups=groups or [], display_model_names=None, model_names=None, + api_key_changed=True, ) llm_response = requests.put( @@ -90,7 +91,7 @@ class LLMProviderManager: @staticmethod def get_all( user_performing_action: DATestUser | None = None, - ) -> list[FullLLMProvider]: + ) -> list[LLMProviderView]: response = requests.get( f"{API_SERVER_URL}/admin/llm/provider", headers=user_performing_action.headers @@ -98,7 +99,7 @@ class LLMProviderManager: else GENERAL_HEADERS, ) response.raise_for_status() - return [FullLLMProvider(**ug) for ug in response.json()] + return [LLMProviderView(**ug) for ug in response.json()] @staticmethod def verify( @@ -111,18 +112,19 @@ class LLMProviderManager: if llm_provider.id == fetched_llm_provider.id: if verify_deleted: raise ValueError( - f"User group {llm_provider.id} found but should be deleted" + f"LLM Provider {llm_provider.id} found but should be deleted" ) fetched_llm_groups = set(fetched_llm_provider.groups) llm_provider_groups = set(llm_provider.groups) + + # NOTE: returned api keys are sanitized and should not match if ( fetched_llm_groups == llm_provider_groups and llm_provider.provider == fetched_llm_provider.provider - and llm_provider.api_key == fetched_llm_provider.api_key and llm_provider.default_model_name == fetched_llm_provider.default_model_name and llm_provider.is_public == fetched_llm_provider.is_public ): return if not verify_deleted: - raise ValueError(f"User group {llm_provider.id} not found") + raise ValueError(f"LLM Provider {llm_provider.id} not found") diff --git a/backend/tests/integration/tests/llm_provider/test_llm_provider.py b/backend/tests/integration/tests/llm_provider/test_llm_provider.py index 1b7d4207e..7c72382f1 100644 --- a/backend/tests/integration/tests/llm_provider/test_llm_provider.py +++ b/backend/tests/integration/tests/llm_provider/test_llm_provider.py @@ -34,6 +34,7 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None: json={ "name": str(uuid.uuid4()), "provider": "openai", + "api_key": "sk-000000000000000000000000000000000000000000000000", "default_model_name": _DEFAULT_MODELS[0], "model_names": _DEFAULT_MODELS, "is_public": True, @@ -49,6 +50,9 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None: assert provider_data["model_names"] == _DEFAULT_MODELS assert provider_data["default_model_name"] == _DEFAULT_MODELS[0] assert provider_data["display_model_names"] is None + assert ( + provider_data["api_key"] == "sk-0****0000" + ) # test that returned key is sanitized def test_update_llm_provider_model_names(reset: None) -> None: @@ -64,10 +68,12 @@ def test_update_llm_provider_model_names(reset: None) -> None: json={ "name": name, "provider": "openai", + "api_key": "sk-000000000000000000000000000000000000000000000000", "default_model_name": _DEFAULT_MODELS[0], "model_names": [_DEFAULT_MODELS[0]], "is_public": True, "groups": [], + "api_key_changed": True, }, ) assert response.status_code == 200 @@ -81,6 +87,7 @@ def test_update_llm_provider_model_names(reset: None) -> None: "id": created_provider["id"], "name": name, "provider": created_provider["provider"], + "api_key": "sk-000000000000000000000000000000000000000000000001", "default_model_name": _DEFAULT_MODELS[0], "model_names": _DEFAULT_MODELS, "is_public": True, @@ -93,6 +100,30 @@ def test_update_llm_provider_model_names(reset: None) -> None: provider_data = _get_provider_by_id(admin_user, created_provider["id"]) assert provider_data is not None assert provider_data["model_names"] == _DEFAULT_MODELS + assert ( + provider_data["api_key"] == "sk-0****0000" + ) # test that key was NOT updated due to api_key_changed not being set + + # Update with api_key_changed properly set + response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider", + headers=admin_user.headers, + json={ + "id": created_provider["id"], + "name": name, + "provider": created_provider["provider"], + "api_key": "sk-000000000000000000000000000000000000000000000001", + "default_model_name": _DEFAULT_MODELS[0], + "model_names": _DEFAULT_MODELS, + "is_public": True, + "groups": [], + "api_key_changed": True, + }, + ) + assert response.status_code == 200 + provider_data = _get_provider_by_id(admin_user, created_provider["id"]) + assert provider_data is not None + assert provider_data["api_key"] == "sk-0****0001" # test that key was updated def test_delete_llm_provider(reset: None) -> None: @@ -107,6 +138,7 @@ def test_delete_llm_provider(reset: None) -> None: json={ "name": "test-provider-delete", "provider": "openai", + "api_key": "sk-000000000000000000000000000000000000000000000000", "default_model_name": _DEFAULT_MODELS[0], "model_names": _DEFAULT_MODELS, "is_public": True, diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index 4ed8c3be9..7137102f5 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -61,7 +61,7 @@ import { import { buildImgUrl } from "@/app/chat/files/images/utils"; import { useAssistants } from "@/components/context/AssistantsContext"; import { debounce } from "lodash"; -import { FullLLMProvider } from "../configuration/llm/interfaces"; +import { LLMProviderView } from "../configuration/llm/interfaces"; import StarterMessagesList from "./StarterMessageList"; import { Switch, SwitchField } from "@/components/ui/switch"; @@ -123,7 +123,7 @@ export function AssistantEditor({ documentSets: DocumentSet[]; user: User | null; defaultPublic: boolean; - llmProviders: FullLLMProvider[]; + llmProviders: LLMProviderView[]; tools: ToolSnapshot[]; shouldAddAssistantToUserPreferences?: boolean; admin?: boolean; diff --git a/web/src/app/admin/assistants/lib.ts b/web/src/app/admin/assistants/lib.ts index a6494782f..70dc8035b 100644 --- a/web/src/app/admin/assistants/lib.ts +++ b/web/src/app/admin/assistants/lib.ts @@ -1,4 +1,4 @@ -import { FullLLMProvider } from "../configuration/llm/interfaces"; +import { LLMProviderView } from "../configuration/llm/interfaces"; import { Persona, StarterMessage } from "./interfaces"; interface PersonaUpsertRequest { @@ -319,7 +319,7 @@ export function checkPersonaRequiresImageGeneration(persona: Persona) { } export function providersContainImageGeneratingSupport( - providers: FullLLMProvider[] + providers: LLMProviderView[] ) { return providers.some((provider) => provider.provider === "openai"); } diff --git a/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx b/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx index 3146d1da7..16b3e7863 100644 --- a/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx +++ b/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx @@ -1,5 +1,5 @@ import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; -import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; +import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces"; import { Modal } from "@/components/Modal"; import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm"; import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm"; @@ -19,7 +19,7 @@ function LLMProviderUpdateModal({ }: { llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined; onClose: () => void; - existingLlmProvider?: FullLLMProvider; + existingLlmProvider?: LLMProviderView; shouldMarkAsDefault?: boolean; setPopup?: (popup: PopupSpec) => void; }) { @@ -61,7 +61,7 @@ function LLMProviderDisplay({ shouldMarkAsDefault, }: { llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined; - existingLlmProvider: FullLLMProvider; + existingLlmProvider: LLMProviderView; shouldMarkAsDefault?: boolean; }) { const [formIsVisible, setFormIsVisible] = useState(false); @@ -146,7 +146,7 @@ export function ConfiguredLLMProviderDisplay({ existingLlmProviders, llmProviderDescriptors, }: { - existingLlmProviders: FullLLMProvider[]; + existingLlmProviders: LLMProviderView[]; llmProviderDescriptors: WellKnownLLMProviderDescriptor[]; }) { existingLlmProviders = existingLlmProviders.sort((a, b) => { diff --git a/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx index 0b175554d..1bdef47e4 100644 --- a/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx @@ -21,7 +21,7 @@ import { } from "@/components/admin/connectors/Field"; import { useState } from "react"; import { useSWRConfig } from "swr"; -import { FullLLMProvider } from "./interfaces"; +import { LLMProviderView } from "./interfaces"; import { PopupSpec } from "@/components/admin/connectors/Popup"; import * as Yup from "yup"; import isEqual from "lodash/isEqual"; @@ -43,7 +43,7 @@ export function CustomLLMProviderUpdateForm({ hideSuccess, }: { onClose: () => void; - existingLlmProvider?: FullLLMProvider; + existingLlmProvider?: LLMProviderView; shouldMarkAsDefault?: boolean; setPopup?: (popup: PopupSpec) => void; hideSuccess?: boolean; @@ -165,7 +165,7 @@ export function CustomLLMProviderUpdateForm({ } if (shouldMarkAsDefault) { - const newLlmProvider = (await response.json()) as FullLLMProvider; + const newLlmProvider = (await response.json()) as LLMProviderView; const setDefaultResponse = await fetch( `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`, { diff --git a/web/src/app/admin/configuration/llm/LLMConfiguration.tsx b/web/src/app/admin/configuration/llm/LLMConfiguration.tsx index 933efa597..56645df4f 100644 --- a/web/src/app/admin/configuration/llm/LLMConfiguration.tsx +++ b/web/src/app/admin/configuration/llm/LLMConfiguration.tsx @@ -9,7 +9,7 @@ import Text from "@/components/ui/text"; import Title from "@/components/ui/title"; import { Button } from "@/components/ui/button"; import { ThreeDotsLoader } from "@/components/Loading"; -import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; +import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces"; import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm"; import { LLM_PROVIDERS_ADMIN_URL } from "./constants"; @@ -25,7 +25,7 @@ function LLMProviderUpdateModal({ }: { llmProviderDescriptor: WellKnownLLMProviderDescriptor | null; onClose: () => void; - existingLlmProvider?: FullLLMProvider; + existingLlmProvider?: LLMProviderView; shouldMarkAsDefault?: boolean; setPopup?: (popup: PopupSpec) => void; }) { @@ -99,7 +99,7 @@ function DefaultLLMProviderDisplay({ function AddCustomLLMProvider({ existingLlmProviders, }: { - existingLlmProviders: FullLLMProvider[]; + existingLlmProviders: LLMProviderView[]; }) { const [formIsVisible, setFormIsVisible] = useState(false); @@ -130,7 +130,7 @@ export function LLMConfiguration() { const { data: llmProviderDescriptors } = useSWR< WellKnownLLMProviderDescriptor[] >("/api/admin/llm/built-in/options", errorHandlingFetcher); - const { data: existingLlmProviders } = useSWR( + const { data: existingLlmProviders } = useSWR( LLM_PROVIDERS_ADMIN_URL, errorHandlingFetcher ); diff --git a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx index f8cb5b6cb..cb2881a31 100644 --- a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx @@ -14,7 +14,7 @@ import { import { useState } from "react"; import { useSWRConfig } from "swr"; import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks"; -import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; +import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces"; import { PopupSpec } from "@/components/admin/connectors/Popup"; import * as Yup from "yup"; import isEqual from "lodash/isEqual"; @@ -31,7 +31,7 @@ export function LLMProviderUpdateForm({ }: { llmProviderDescriptor: WellKnownLLMProviderDescriptor; onClose: () => void; - existingLlmProvider?: FullLLMProvider; + existingLlmProvider?: LLMProviderView; shouldMarkAsDefault?: boolean; hideAdvanced?: boolean; setPopup?: (popup: PopupSpec) => void; @@ -73,6 +73,7 @@ export function LLMProviderUpdateForm({ defaultModelsByProvider[llmProviderDescriptor.name] || [], deployment_name: existingLlmProvider?.deployment_name, + api_key_changed: false, }; // Setup validation schema if required @@ -113,6 +114,7 @@ export function LLMProviderUpdateForm({ is_public: Yup.boolean().required(), groups: Yup.array().of(Yup.number()), display_model_names: Yup.array().of(Yup.string()), + api_key_changed: Yup.boolean(), }); return ( @@ -122,6 +124,8 @@ export function LLMProviderUpdateForm({ onSubmit={async (values, { setSubmitting }) => { setSubmitting(true); + values.api_key_changed = values.api_key !== initialValues.api_key; + // test the configuration if (!isEqual(values, initialValues)) { setIsTesting(true); @@ -180,7 +184,7 @@ export function LLMProviderUpdateForm({ } if (shouldMarkAsDefault) { - const newLlmProvider = (await response.json()) as FullLLMProvider; + const newLlmProvider = (await response.json()) as LLMProviderView; const setDefaultResponse = await fetch( `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`, { diff --git a/web/src/app/admin/configuration/llm/interfaces.ts b/web/src/app/admin/configuration/llm/interfaces.ts index 641a372d2..80971e0cc 100644 --- a/web/src/app/admin/configuration/llm/interfaces.ts +++ b/web/src/app/admin/configuration/llm/interfaces.ts @@ -53,14 +53,14 @@ export interface LLMProvider { is_default_vision_provider: boolean | null; } -export interface FullLLMProvider extends LLMProvider { +export interface LLMProviderView extends LLMProvider { id: number; is_default_provider: boolean | null; model_names: string[]; icon?: React.FC<{ size?: number; className?: string }>; } -export interface VisionProvider extends FullLLMProvider { +export interface VisionProvider extends LLMProviderView { vision_models: string[]; } diff --git a/web/src/components/initialSetup/welcome/lib.ts b/web/src/components/initialSetup/welcome/lib.ts index 5cbe54cc3..822b9f1ea 100644 --- a/web/src/components/initialSetup/welcome/lib.ts +++ b/web/src/components/initialSetup/welcome/lib.ts @@ -1,5 +1,5 @@ import { - FullLLMProvider, + LLMProviderView, WellKnownLLMProviderDescriptor, } from "@/app/admin/configuration/llm/interfaces"; import { User } from "@/lib/types"; @@ -36,7 +36,7 @@ export async function checkLlmProvider(user: User | null) { const [providerResponse, optionsResponse, defaultCheckResponse] = await Promise.all(tasks); - let providers: FullLLMProvider[] = []; + let providers: LLMProviderView[] = []; if (providerResponse?.ok) { providers = await providerResponse.json(); } diff --git a/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts b/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts index 2c00e0dfc..3d4a85845 100644 --- a/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts +++ b/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts @@ -3,7 +3,7 @@ import { CCPairBasicInfo, DocumentSet, User } from "../types"; import { getCurrentUserSS } from "../userSS"; import { fetchSS } from "../utilsSS"; import { - FullLLMProvider, + LLMProviderView, getProviderIcon, } from "@/app/admin/configuration/llm/interfaces"; import { ToolSnapshot } from "../tools/interfaces"; @@ -16,7 +16,7 @@ export async function fetchAssistantEditorInfoSS( { ccPairs: CCPairBasicInfo[]; documentSets: DocumentSet[]; - llmProviders: FullLLMProvider[]; + llmProviders: LLMProviderView[]; user: User | null; existingPersona: Persona | null; tools: ToolSnapshot[]; @@ -83,7 +83,7 @@ export async function fetchAssistantEditorInfoSS( ]; } - const llmProviders = (await llmProvidersResponse.json()) as FullLLMProvider[]; + const llmProviders = (await llmProvidersResponse.json()) as LLMProviderView[]; if (personaId && personaResponse && !personaResponse.ok) { return [null, `Failed to fetch Persona - ${await personaResponse.text()}`];