From 660021bf29c1e315e2dee50d22ea0ddc53dfa572 Mon Sep 17 00:00:00 2001 From: "Richard Kuo (Danswer)" Date: Wed, 12 Mar 2025 15:19:59 -0700 Subject: [PATCH] sanitize llm keys and handle updates properly --- .../ee/onyx/server/tenants/provisioning.py | 6 ++-- backend/onyx/db/llm.py | 26 +++++++++++----- backend/onyx/llm/factory.py | 4 +-- backend/onyx/server/manage/llm/api.py | 29 ++++++++++++------ backend/onyx/server/manage/llm/models.py | 7 +++-- backend/onyx/setup.py | 1 + .../common_utils/managers/llm_provider.py | 7 +++-- .../tests/llm_provider/test_llm_provider.py | 30 +++++++++++++++++++ .../llm/LLMProviderUpdateForm.tsx | 4 +++ 9 files changed, 89 insertions(+), 25 deletions(-) diff --git a/backend/ee/onyx/server/tenants/provisioning.py b/backend/ee/onyx/server/tenants/provisioning.py index d3e523314..332710e14 100644 --- a/backend/ee/onyx/server/tenants/provisioning.py +++ b/backend/ee/onyx/server/tenants/provisioning.py @@ -215,6 +215,7 @@ def configure_default_api_keys(db_session: Session) -> None: fast_default_model_name="claude-3-5-sonnet-20241022", model_names=ANTHROPIC_MODEL_NAMES, display_model_names=["claude-3-5-sonnet-20241022"], + api_key_changed=True, ) try: full_provider = upsert_llm_provider(anthropic_provider, db_session) @@ -227,7 +228,7 @@ def configure_default_api_keys(db_session: Session) -> None: ) if OPENAI_DEFAULT_API_KEY: - open_provider = LLMProviderUpsertRequest( + openai_provider = LLMProviderUpsertRequest( name="OpenAI", provider=OPENAI_PROVIDER_NAME, api_key=OPENAI_DEFAULT_API_KEY, @@ -235,9 +236,10 @@ def configure_default_api_keys(db_session: Session) -> None: fast_default_model_name="gpt-4o-mini", model_names=OPEN_AI_MODEL_NAMES, display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"], + api_key_changed=True, ) try: - full_provider = upsert_llm_provider(open_provider, db_session) + full_provider = upsert_llm_provider(openai_provider, db_session) update_default_provider(full_provider.id, db_session) except Exception as e: 
logger.error(f"Failed to configure OpenAI provider: {e}") diff --git a/backend/onyx/db/llm.py b/backend/onyx/db/llm.py index eff919295..8d3f50944 100644 --- a/backend/onyx/db/llm.py +++ b/backend/onyx/db/llm.py @@ -15,8 +15,8 @@ from onyx.db.models import User from onyx.db.models import User__UserGroup from onyx.server.manage.embedding.models import CloudEmbeddingProvider from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest -from onyx.server.manage.llm.models import FullLLMProvider from onyx.server.manage.llm.models import LLMProviderUpsertRequest +from onyx.server.manage.llm.models import LLMProviderView from shared_configs.enums import EmbeddingProvider @@ -66,7 +66,7 @@ def upsert_cloud_embedding_provider( def upsert_llm_provider( llm_provider: LLMProviderUpsertRequest, db_session: Session, -) -> FullLLMProvider: +) -> LLMProviderView: existing_llm_provider = db_session.scalar( select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name) ) @@ -97,7 +97,7 @@ def upsert_llm_provider( group_ids=llm_provider.groups, db_session=db_session, ) - full_llm_provider = FullLLMProvider.from_model(existing_llm_provider) + full_llm_provider = LLMProviderView.from_model(existing_llm_provider) db_session.commit() @@ -131,6 +131,16 @@ def fetch_existing_llm_providers( return list(db_session.scalars(stmt).all()) +def fetch_existing_llm_provider( + provider_name: str, db_session: Session +) -> LLMProviderModel | None: + provider_model = db_session.scalar( + select(LLMProviderModel).where(LLMProviderModel.name == provider_name) + ) + + return provider_model + + def fetch_existing_llm_providers_for_user( db_session: Session, user: User | None = None, @@ -176,7 +186,7 @@ def fetch_embedding_provider( ) -def fetch_default_provider(db_session: Session) -> FullLLMProvider | None: +def fetch_default_provider(db_session: Session) -> LLMProviderView | None: provider_model = db_session.scalar( select(LLMProviderModel).where( 
LLMProviderModel.is_default_provider == True # noqa: E712 @@ -184,16 +194,18 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None: ) if not provider_model: return None - return FullLLMProvider.from_model(provider_model) + return LLMProviderView.from_model(provider_model) -def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None: +def fetch_llm_provider_view( + db_session: Session, provider_name: str +) -> LLMProviderView | None: provider_model = db_session.scalar( select(LLMProviderModel).where(LLMProviderModel.name == provider_name) ) if not provider_model: return None - return FullLLMProvider.from_model(provider_model) + return LLMProviderView.from_model(provider_model) def remove_embedding_provider( diff --git a/backend/onyx/llm/factory.py b/backend/onyx/llm/factory.py index 4c8a5f093..314b681b4 100644 --- a/backend/onyx/llm/factory.py +++ b/backend/onyx/llm/factory.py @@ -7,7 +7,7 @@ from onyx.configs.model_configs import GEN_AI_TEMPERATURE from onyx.db.engine import get_session_context_manager from onyx.db.llm import fetch_default_provider from onyx.db.llm import fetch_existing_llm_providers -from onyx.db.llm import fetch_provider +from onyx.db.llm import fetch_llm_provider_view from onyx.db.models import Persona from onyx.llm.chat_llm import DefaultMultiLLM from onyx.llm.exceptions import GenAIDisabledException @@ -59,7 +59,7 @@ def get_llms_for_persona( ) with get_session_context_manager() as db_session: - llm_provider = fetch_provider(db_session, provider_name) + llm_provider = fetch_llm_provider_view(db_session, provider_name) if not llm_provider: raise ValueError("No LLM provider found") diff --git a/backend/onyx/server/manage/llm/api.py b/backend/onyx/server/manage/llm/api.py index 7a76ed196..1120c3762 100644 --- a/backend/onyx/server/manage/llm/api.py +++ b/backend/onyx/server/manage/llm/api.py @@ -9,9 +9,9 @@ from sqlalchemy.orm import Session from onyx.auth.users import current_admin_user from 
onyx.auth.users import current_chat_accessible_user from onyx.db.engine import get_session +from onyx.db.llm import fetch_existing_llm_provider from onyx.db.llm import fetch_existing_llm_providers from onyx.db.llm import fetch_existing_llm_providers_for_user -from onyx.db.llm import fetch_provider from onyx.db.llm import remove_llm_provider from onyx.db.llm import update_default_provider from onyx.db.llm import upsert_llm_provider @@ -22,9 +22,9 @@ from onyx.llm.llm_provider_options import fetch_available_well_known_llms from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor from onyx.llm.utils import litellm_exception_to_error_msg from onyx.llm.utils import test_llm -from onyx.server.manage.llm.models import FullLLMProvider from onyx.server.manage.llm.models import LLMProviderDescriptor from onyx.server.manage.llm.models import LLMProviderUpsertRequest +from onyx.server.manage.llm.models import LLMProviderView from onyx.server.manage.llm.models import TestLLMRequest from onyx.utils.logger import setup_logger from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel @@ -116,11 +116,17 @@ def test_default_provider( def list_llm_providers( _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), -) -> list[FullLLMProvider]: - return [ - FullLLMProvider.from_model(llm_provider_model) - for llm_provider_model in fetch_existing_llm_providers(db_session) - ] +) -> list[LLMProviderView]: + llm_provider_list: list[LLMProviderView] = [] + for llm_provider_model in fetch_existing_llm_providers(db_session): + full_llm_provider = LLMProviderView.from_model(llm_provider_model) + if full_llm_provider.api_key: + full_llm_provider.api_key = ( + full_llm_provider.api_key[:4] + "****" + full_llm_provider.api_key[-4:] + ) + llm_provider_list.append(full_llm_provider) + + return llm_provider_list @admin_router.put("/provider") @@ -132,11 +138,11 @@ def put_llm_provider( ), _: User | None = 
Depends(current_admin_user), db_session: Session = Depends(get_session), -) -> FullLLMProvider: +) -> LLMProviderView: # validate request (e.g. if we're intending to create but the name already exists we should throw an error) # NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache # the result - existing_provider = fetch_provider(db_session, llm_provider.name) + existing_provider = fetch_existing_llm_provider(llm_provider.name, db_session) if existing_provider and is_creation: raise HTTPException( status_code=400, @@ -158,6 +164,11 @@ def put_llm_provider( llm_provider.fast_default_model_name ) + # the llm api key is sanitized when returned to clients, so the only time we + # should get a real key is when it is explicitly changed + if existing_provider and not llm_provider.api_key_changed: + llm_provider.api_key = existing_provider.api_key + try: return upsert_llm_provider( llm_provider=llm_provider, diff --git a/backend/onyx/server/manage/llm/models.py b/backend/onyx/server/manage/llm/models.py index 91c59fb15..ace49052a 100644 --- a/backend/onyx/server/manage/llm/models.py +++ b/backend/onyx/server/manage/llm/models.py @@ -74,15 +74,18 @@ class LLMProviderUpsertRequest(LLMProvider): # should only be used for a "custom" provider # for default providers, the built-in model names are used model_names: list[str] | None = None + api_key_changed: bool = False -class FullLLMProvider(LLMProvider): +class LLMProviderView(LLMProvider): + """Stripped down representation of LLMProvider for display / limited access info only""" + id: int is_default_provider: bool | None = None model_names: list[str] @classmethod - def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider": + def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView": return cls( id=llm_provider_model.id, name=llm_provider_model.name, diff --git a/backend/onyx/setup.py b/backend/onyx/setup.py index 
1dff601ef..750b35d8d 100644 --- a/backend/onyx/setup.py +++ b/backend/onyx/setup.py @@ -307,6 +307,7 @@ def setup_postgres(db_session: Session) -> None: groups=[], display_model_names=OPEN_AI_MODEL_NAMES, model_names=OPEN_AI_MODEL_NAMES, + api_key_changed=True, ) new_llm_provider = upsert_llm_provider( llm_provider=model_req, db_session=db_session diff --git a/backend/tests/integration/common_utils/managers/llm_provider.py b/backend/tests/integration/common_utils/managers/llm_provider.py index 44d4ce501..f41874956 100644 --- a/backend/tests/integration/common_utils/managers/llm_provider.py +++ b/backend/tests/integration/common_utils/managers/llm_provider.py @@ -3,8 +3,8 @@ from uuid import uuid4 import requests -from onyx.server.manage.llm.models import FullLLMProvider from onyx.server.manage.llm.models import LLMProviderUpsertRequest +from onyx.server.manage.llm.models import LLMProviderView from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.test_models import DATestLLMProvider @@ -39,6 +39,7 @@ class LLMProviderManager: groups=groups or [], display_model_names=None, model_names=None, + api_key_changed=True, ) llm_response = requests.put( @@ -90,7 +91,7 @@ class LLMProviderManager: @staticmethod def get_all( user_performing_action: DATestUser | None = None, - ) -> list[FullLLMProvider]: + ) -> list[LLMProviderView]: response = requests.get( f"{API_SERVER_URL}/admin/llm/provider", headers=user_performing_action.headers @@ -98,7 +99,7 @@ class LLMProviderManager: else GENERAL_HEADERS, ) response.raise_for_status() - return [FullLLMProvider(**ug) for ug in response.json()] + return [LLMProviderView(**ug) for ug in response.json()] @staticmethod def verify( diff --git a/backend/tests/integration/tests/llm_provider/test_llm_provider.py b/backend/tests/integration/tests/llm_provider/test_llm_provider.py index 1b7d4207e..df72041e3 100644 --- 
a/backend/tests/integration/tests/llm_provider/test_llm_provider.py +++ b/backend/tests/integration/tests/llm_provider/test_llm_provider.py @@ -34,6 +34,7 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None: json={ "name": str(uuid.uuid4()), "provider": "openai", + "api_key": "sk-000000000000000000000000000000000000000000000000", "default_model_name": _DEFAULT_MODELS[0], "model_names": _DEFAULT_MODELS, "is_public": True, @@ -49,6 +50,9 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None: assert provider_data["model_names"] == _DEFAULT_MODELS assert provider_data["default_model_name"] == _DEFAULT_MODELS[0] assert provider_data["display_model_names"] is None + assert ( + provider_data["api_key"] == "sk-0****0000" + ) # test that returned key is sanitized def test_update_llm_provider_model_names(reset: None) -> None: @@ -64,10 +68,12 @@ def test_update_llm_provider_model_names(reset: None) -> None: json={ "name": name, "provider": "openai", + "api_key": "sk-000000000000000000000000000000000000000000000000", "default_model_name": _DEFAULT_MODELS[0], "model_names": [_DEFAULT_MODELS[0]], "is_public": True, "groups": [], + "api_key_changed": True, }, ) assert response.status_code == 200 @@ -81,6 +87,7 @@ def test_update_llm_provider_model_names(reset: None) -> None: "id": created_provider["id"], "name": name, "provider": created_provider["provider"], + "api_key": "sk-000000000000000000000000000000000000000000000001", "default_model_name": _DEFAULT_MODELS[0], "model_names": _DEFAULT_MODELS, "is_public": True, @@ -93,6 +100,28 @@ def test_update_llm_provider_model_names(reset: None) -> None: provider_data = _get_provider_by_id(admin_user, created_provider["id"]) assert provider_data is not None assert provider_data["model_names"] == _DEFAULT_MODELS + assert ( + provider_data["api_key"] == "sk-0****0000" + ) # test that key was NOT updated due to api_key_changed not being set + + # Update with api_key_changed properly 
set + response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider", + headers=admin_user.headers, + json={ + "id": created_provider["id"], + "name": name, + "provider": created_provider["provider"], + "api_key": "sk-000000000000000000000000000000000000000000000001", + "default_model_name": _DEFAULT_MODELS[0], + "model_names": _DEFAULT_MODELS, + "is_public": True, + "groups": [], + "api_key_changed": True, + }, + ) + assert response.status_code == 200 + assert _get_provider_by_id(admin_user, created_provider["id"])["api_key"] == "sk-0****0001"  # re-fetch: key WAS updated since api_key_changed was set def test_delete_llm_provider(reset: None) -> None: @@ -107,6 +136,7 @@ json={ + "name": "test-provider-delete", + "provider": "openai", + "api_key": "sk-000000000000000000000000000000000000000000000000", + "default_model_name": _DEFAULT_MODELS[0], + "model_names": _DEFAULT_MODELS, + "is_public": True, diff --git a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx index 3a5565013..8aab3451d 100644 --- a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx @@ -73,6 +73,7 @@ export function LLMProviderUpdateForm({ defaultModelsByProvider[llmProviderDescriptor.name] || [], deployment_name: existingLlmProvider?.deployment_name, + api_key_changed: false, }; // Setup validation schema if required @@ -113,6 +114,7 @@ export function LLMProviderUpdateForm({ is_public: Yup.boolean().required(), groups: Yup.array().of(Yup.number()), display_model_names: Yup.array().of(Yup.string()), + api_key_changed: Yup.boolean(), }); return ( @@ -122,6 +124,8 @@ export function LLMProviderUpdateForm({ onSubmit={async (values, { setSubmitting }) => { setSubmitting(true); + values.api_key_changed = values.api_key !== initialValues.api_key; + // test the configuration if (!isEqual(values, initialValues)) { setIsTesting(true);