sanitize llm keys and handle updates properly (#4270)

* sanitize llm keys and handle updates properly

* fix llm provider testing

* fix test

* mypy

* fix default model editing

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
This commit is contained in:
rkuo-danswer
2025-03-19 18:13:02 -07:00
committed by GitHub
parent 5dda53eec3
commit 85ebadc8eb
17 changed files with 146 additions and 62 deletions

View File

@@ -271,6 +271,7 @@ def configure_default_api_keys(db_session: Session) -> None:
fast_default_model_name="claude-3-5-sonnet-20241022", fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES, model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"], display_model_names=["claude-3-5-sonnet-20241022"],
api_key_changed=True,
) )
try: try:
full_provider = upsert_llm_provider(anthropic_provider, db_session) full_provider = upsert_llm_provider(anthropic_provider, db_session)
@@ -283,7 +284,7 @@ def configure_default_api_keys(db_session: Session) -> None:
) )
if OPENAI_DEFAULT_API_KEY: if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest( openai_provider = LLMProviderUpsertRequest(
name="OpenAI", name="OpenAI",
provider=OPENAI_PROVIDER_NAME, provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY, api_key=OPENAI_DEFAULT_API_KEY,
@@ -291,9 +292,10 @@ def configure_default_api_keys(db_session: Session) -> None:
fast_default_model_name="gpt-4o-mini", fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES, model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"], display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
api_key_changed=True,
) )
try: try:
full_provider = upsert_llm_provider(open_provider, db_session) full_provider = upsert_llm_provider(openai_provider, db_session)
update_default_provider(full_provider.id, db_session) update_default_provider(full_provider.id, db_session)
except Exception as e: except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}") logger.error(f"Failed to configure OpenAI provider: {e}")

View File

@@ -16,8 +16,8 @@ from onyx.db.models import User__UserGroup
from onyx.llm.utils import model_supports_image_input from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.embedding.models import CloudEmbeddingProvider from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderUpsertRequest from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbeddingProvider
@@ -67,7 +67,7 @@ def upsert_cloud_embedding_provider(
def upsert_llm_provider( def upsert_llm_provider(
llm_provider: LLMProviderUpsertRequest, llm_provider: LLMProviderUpsertRequest,
db_session: Session, db_session: Session,
) -> FullLLMProvider: ) -> LLMProviderView:
existing_llm_provider = db_session.scalar( existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name) select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
) )
@@ -98,7 +98,7 @@ def upsert_llm_provider(
group_ids=llm_provider.groups, group_ids=llm_provider.groups,
db_session=db_session, db_session=db_session,
) )
full_llm_provider = FullLLMProvider.from_model(existing_llm_provider) full_llm_provider = LLMProviderView.from_model(existing_llm_provider)
db_session.commit() db_session.commit()
@@ -132,6 +132,16 @@ def fetch_existing_llm_providers(
return list(db_session.scalars(stmt).all()) return list(db_session.scalars(stmt).all())
def fetch_existing_llm_provider(
provider_name: str, db_session: Session
) -> LLMProviderModel | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
)
return provider_model
def fetch_existing_llm_providers_for_user( def fetch_existing_llm_providers_for_user(
db_session: Session, db_session: Session,
user: User | None = None, user: User | None = None,
@@ -177,7 +187,7 @@ def fetch_embedding_provider(
) )
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None: def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar( provider_model = db_session.scalar(
select(LLMProviderModel).where( select(LLMProviderModel).where(
LLMProviderModel.is_default_provider == True # noqa: E712 LLMProviderModel.is_default_provider == True # noqa: E712
@@ -185,10 +195,10 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
) )
if not provider_model: if not provider_model:
return None return None
return FullLLMProvider.from_model(provider_model) return LLMProviderView.from_model(provider_model)
def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None: def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar( provider_model = db_session.scalar(
select(LLMProviderModel).where( select(LLMProviderModel).where(
LLMProviderModel.is_default_vision_provider == True # noqa: E712 LLMProviderModel.is_default_vision_provider == True # noqa: E712
@@ -196,16 +206,18 @@ def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None
) )
if not provider_model: if not provider_model:
return None return None
return FullLLMProvider.from_model(provider_model) return LLMProviderView.from_model(provider_model)
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None: def fetch_llm_provider_view(
db_session: Session, provider_name: str
) -> LLMProviderView | None:
provider_model = db_session.scalar( provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name) select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
) )
if not provider_model: if not provider_model:
return None return None
return FullLLMProvider.from_model(provider_model) return LLMProviderView.from_model(provider_model)
def remove_embedding_provider( def remove_embedding_provider(

View File

@@ -9,14 +9,14 @@ from onyx.db.engine import get_session_with_current_tenant
from onyx.db.llm import fetch_default_provider from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_vision_provider from onyx.db.llm import fetch_default_vision_provider
from onyx.db.llm import fetch_existing_llm_providers from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_provider from onyx.db.llm import fetch_llm_provider_view
from onyx.db.models import Persona from onyx.db.models import Persona
from onyx.llm.chat_llm import DefaultMultiLLM from onyx.llm.chat_llm import DefaultMultiLLM
from onyx.llm.exceptions import GenAIDisabledException from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import model_supports_image_input from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import FullLLMProvider from onyx.server.manage.llm.models import LLMProviderView
from onyx.utils.headers import build_llm_extra_headers from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger from onyx.utils.long_term_log import LongTermLogger
@@ -62,7 +62,7 @@ def get_llms_for_persona(
) )
with get_session_context_manager() as db_session: with get_session_context_manager() as db_session:
llm_provider = fetch_provider(db_session, provider_name) llm_provider = fetch_llm_provider_view(db_session, provider_name)
if not llm_provider: if not llm_provider:
raise ValueError("No LLM provider found") raise ValueError("No LLM provider found")
@@ -106,7 +106,7 @@ def get_default_llm_with_vision(
if DISABLE_GENERATIVE_AI: if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException() raise GenAIDisabledException()
def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM: def create_vision_llm(provider: LLMProviderView, model: str) -> LLM:
"""Helper to create an LLM if the provider supports image input.""" """Helper to create an LLM if the provider supports image input."""
return get_llm( return get_llm(
provider=provider.provider, provider=provider.provider,
@@ -148,7 +148,7 @@ def get_default_llm_with_vision(
provider.default_vision_model, provider.provider provider.default_vision_model, provider.provider
): ):
return create_vision_llm( return create_vision_llm(
FullLLMProvider.from_model(provider), provider.default_vision_model LLMProviderView.from_model(provider), provider.default_vision_model
) )
return None return None

View File

@@ -9,9 +9,9 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine import get_session from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user from onyx.db.llm import fetch_existing_llm_providers_for_user
from onyx.db.llm import fetch_provider
from onyx.db.llm import remove_llm_provider from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_vision_provider from onyx.db.llm import update_default_vision_provider
@@ -24,9 +24,9 @@ from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import litellm_exception_to_error_msg from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import model_supports_image_input from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderDescriptor from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import TestLLMRequest from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
@@ -49,11 +49,27 @@ def fetch_llm_options(
def test_llm_configuration( def test_llm_configuration(
test_llm_request: TestLLMRequest, test_llm_request: TestLLMRequest,
_: User | None = Depends(current_admin_user), _: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None: ) -> None:
"""Test regular llm and fast llm settings"""
# the api key is sanitized if we are testing a provider already in the system
test_api_key = test_llm_request.api_key
if test_llm_request.name:
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
# as it turns out the name is not editable in the UI and other code also keys off name,
# so we won't rock the boat just yet.
existing_provider = fetch_existing_llm_provider(
test_llm_request.name, db_session
)
if existing_provider:
test_api_key = existing_provider.api_key
llm = get_llm( llm = get_llm(
provider=test_llm_request.provider, provider=test_llm_request.provider,
model=test_llm_request.default_model_name, model=test_llm_request.default_model_name,
api_key=test_llm_request.api_key, api_key=test_api_key,
api_base=test_llm_request.api_base, api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version, api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config, custom_config=test_llm_request.custom_config,
@@ -69,7 +85,7 @@ def test_llm_configuration(
fast_llm = get_llm( fast_llm = get_llm(
provider=test_llm_request.provider, provider=test_llm_request.provider,
model=test_llm_request.fast_default_model_name, model=test_llm_request.fast_default_model_name,
api_key=test_llm_request.api_key, api_key=test_api_key,
api_base=test_llm_request.api_base, api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version, api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config, custom_config=test_llm_request.custom_config,
@@ -119,11 +135,17 @@ def test_default_provider(
def list_llm_providers( def list_llm_providers(
_: User | None = Depends(current_admin_user), _: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session), db_session: Session = Depends(get_session),
) -> list[FullLLMProvider]: ) -> list[LLMProviderView]:
return [ llm_provider_list: list[LLMProviderView] = []
FullLLMProvider.from_model(llm_provider_model) for llm_provider_model in fetch_existing_llm_providers(db_session):
for llm_provider_model in fetch_existing_llm_providers(db_session) full_llm_provider = LLMProviderView.from_model(llm_provider_model)
] if full_llm_provider.api_key:
full_llm_provider.api_key = (
full_llm_provider.api_key[:4] + "****" + full_llm_provider.api_key[-4:]
)
llm_provider_list.append(full_llm_provider)
return llm_provider_list
@admin_router.put("/provider") @admin_router.put("/provider")
@@ -135,11 +157,11 @@ def put_llm_provider(
), ),
_: User | None = Depends(current_admin_user), _: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session), db_session: Session = Depends(get_session),
) -> FullLLMProvider: ) -> LLMProviderView:
# validate request (e.g. if we're intending to create but the name already exists we should throw an error) # validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache # NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result # the result
existing_provider = fetch_provider(db_session, llm_provider.name) existing_provider = fetch_existing_llm_provider(llm_provider.name, db_session)
if existing_provider and is_creation: if existing_provider and is_creation:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
@@ -161,6 +183,11 @@ def put_llm_provider(
llm_provider.fast_default_model_name llm_provider.fast_default_model_name
) )
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider.api_key_changed:
llm_provider.api_key = existing_provider.api_key
try: try:
return upsert_llm_provider( return upsert_llm_provider(
llm_provider=llm_provider, llm_provider=llm_provider,
@@ -234,7 +261,7 @@ def get_vision_capable_providers(
# Only include providers with at least one vision-capable model # Only include providers with at least one vision-capable model
if vision_models: if vision_models:
provider_dict = FullLLMProvider.from_model(provider).model_dump() provider_dict = LLMProviderView.from_model(provider).model_dump()
provider_dict["vision_models"] = vision_models provider_dict["vision_models"] = vision_models
logger.info( logger.info(
f"Vision provider: {provider.provider} with models: {vision_models}" f"Vision provider: {provider.provider} with models: {vision_models}"

View File

@@ -12,6 +12,7 @@ if TYPE_CHECKING:
class TestLLMRequest(BaseModel): class TestLLMRequest(BaseModel):
# provider level # provider level
name: str | None = None
provider: str provider: str
api_key: str | None = None api_key: str | None = None
api_base: str | None = None api_base: str | None = None
@@ -76,16 +77,19 @@ class LLMProviderUpsertRequest(LLMProvider):
# should only be used for a "custom" provider # should only be used for a "custom" provider
# for default providers, the built-in model names are used # for default providers, the built-in model names are used
model_names: list[str] | None = None model_names: list[str] | None = None
api_key_changed: bool = False
class FullLLMProvider(LLMProvider): class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
id: int id: int
is_default_provider: bool | None = None is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None is_default_vision_provider: bool | None = None
model_names: list[str] model_names: list[str]
@classmethod @classmethod
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider": def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
return cls( return cls(
id=llm_provider_model.id, id=llm_provider_model.id,
name=llm_provider_model.name, name=llm_provider_model.name,
@@ -111,7 +115,7 @@ class FullLLMProvider(LLMProvider):
) )
class VisionProviderResponse(FullLLMProvider): class VisionProviderResponse(LLMProviderView):
"""Response model for vision providers endpoint, including vision-specific fields.""" """Response model for vision providers endpoint, including vision-specific fields."""
vision_models: list[str] vision_models: list[str]

View File

@@ -307,6 +307,7 @@ def setup_postgres(db_session: Session) -> None:
groups=[], groups=[],
display_model_names=OPEN_AI_MODEL_NAMES, display_model_names=OPEN_AI_MODEL_NAMES,
model_names=OPEN_AI_MODEL_NAMES, model_names=OPEN_AI_MODEL_NAMES,
api_key_changed=True,
) )
new_llm_provider = upsert_llm_provider( new_llm_provider = upsert_llm_provider(
llm_provider=model_req, db_session=db_session llm_provider=model_req, db_session=db_session

View File

@@ -3,8 +3,8 @@ from uuid import uuid4
import requests import requests
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderUpsertRequest from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider from tests.integration.common_utils.test_models import DATestLLMProvider
@@ -39,6 +39,7 @@ class LLMProviderManager:
groups=groups or [], groups=groups or [],
display_model_names=None, display_model_names=None,
model_names=None, model_names=None,
api_key_changed=True,
) )
llm_response = requests.put( llm_response = requests.put(
@@ -90,7 +91,7 @@ class LLMProviderManager:
@staticmethod @staticmethod
def get_all( def get_all(
user_performing_action: DATestUser | None = None, user_performing_action: DATestUser | None = None,
) -> list[FullLLMProvider]: ) -> list[LLMProviderView]:
response = requests.get( response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider", f"{API_SERVER_URL}/admin/llm/provider",
headers=user_performing_action.headers headers=user_performing_action.headers
@@ -98,7 +99,7 @@ class LLMProviderManager:
else GENERAL_HEADERS, else GENERAL_HEADERS,
) )
response.raise_for_status() response.raise_for_status()
return [FullLLMProvider(**ug) for ug in response.json()] return [LLMProviderView(**ug) for ug in response.json()]
@staticmethod @staticmethod
def verify( def verify(
@@ -111,18 +112,19 @@ class LLMProviderManager:
if llm_provider.id == fetched_llm_provider.id: if llm_provider.id == fetched_llm_provider.id:
if verify_deleted: if verify_deleted:
raise ValueError( raise ValueError(
f"User group {llm_provider.id} found but should be deleted" f"LLM Provider {llm_provider.id} found but should be deleted"
) )
fetched_llm_groups = set(fetched_llm_provider.groups) fetched_llm_groups = set(fetched_llm_provider.groups)
llm_provider_groups = set(llm_provider.groups) llm_provider_groups = set(llm_provider.groups)
# NOTE: returned api keys are sanitized and should not match
if ( if (
fetched_llm_groups == llm_provider_groups fetched_llm_groups == llm_provider_groups
and llm_provider.provider == fetched_llm_provider.provider and llm_provider.provider == fetched_llm_provider.provider
and llm_provider.api_key == fetched_llm_provider.api_key
and llm_provider.default_model_name and llm_provider.default_model_name
== fetched_llm_provider.default_model_name == fetched_llm_provider.default_model_name
and llm_provider.is_public == fetched_llm_provider.is_public and llm_provider.is_public == fetched_llm_provider.is_public
): ):
return return
if not verify_deleted: if not verify_deleted:
raise ValueError(f"User group {llm_provider.id} not found") raise ValueError(f"LLM Provider {llm_provider.id} not found")

View File

@@ -34,6 +34,7 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
json={ json={
"name": str(uuid.uuid4()), "name": str(uuid.uuid4()),
"provider": "openai", "provider": "openai",
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": _DEFAULT_MODELS[0], "default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS, "model_names": _DEFAULT_MODELS,
"is_public": True, "is_public": True,
@@ -49,6 +50,9 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
assert provider_data["model_names"] == _DEFAULT_MODELS assert provider_data["model_names"] == _DEFAULT_MODELS
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0] assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
assert provider_data["display_model_names"] is None assert provider_data["display_model_names"] is None
assert (
provider_data["api_key"] == "sk-0****0000"
) # test that returned key is sanitized
def test_update_llm_provider_model_names(reset: None) -> None: def test_update_llm_provider_model_names(reset: None) -> None:
@@ -64,10 +68,12 @@ def test_update_llm_provider_model_names(reset: None) -> None:
json={ json={
"name": name, "name": name,
"provider": "openai", "provider": "openai",
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": _DEFAULT_MODELS[0], "default_model_name": _DEFAULT_MODELS[0],
"model_names": [_DEFAULT_MODELS[0]], "model_names": [_DEFAULT_MODELS[0]],
"is_public": True, "is_public": True,
"groups": [], "groups": [],
"api_key_changed": True,
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -81,6 +87,7 @@ def test_update_llm_provider_model_names(reset: None) -> None:
"id": created_provider["id"], "id": created_provider["id"],
"name": name, "name": name,
"provider": created_provider["provider"], "provider": created_provider["provider"],
"api_key": "sk-000000000000000000000000000000000000000000000001",
"default_model_name": _DEFAULT_MODELS[0], "default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS, "model_names": _DEFAULT_MODELS,
"is_public": True, "is_public": True,
@@ -93,6 +100,30 @@ def test_update_llm_provider_model_names(reset: None) -> None:
provider_data = _get_provider_by_id(admin_user, created_provider["id"]) provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is not None assert provider_data is not None
assert provider_data["model_names"] == _DEFAULT_MODELS assert provider_data["model_names"] == _DEFAULT_MODELS
assert (
provider_data["api_key"] == "sk-0****0000"
) # test that key was NOT updated due to api_key_changed not being set
# Update with api_key_changed properly set
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"id": created_provider["id"],
"name": name,
"provider": created_provider["provider"],
"api_key": "sk-000000000000000000000000000000000000000000000001",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
"groups": [],
"api_key_changed": True,
},
)
assert response.status_code == 200
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is not None
assert provider_data["api_key"] == "sk-0****0001" # test that key was updated
def test_delete_llm_provider(reset: None) -> None: def test_delete_llm_provider(reset: None) -> None:
@@ -107,6 +138,7 @@ def test_delete_llm_provider(reset: None) -> None:
json={ json={
"name": "test-provider-delete", "name": "test-provider-delete",
"provider": "openai", "provider": "openai",
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": _DEFAULT_MODELS[0], "default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS, "model_names": _DEFAULT_MODELS,
"is_public": True, "is_public": True,

View File

@@ -61,7 +61,7 @@ import {
import { buildImgUrl } from "@/app/chat/files/images/utils"; import { buildImgUrl } from "@/app/chat/files/images/utils";
import { useAssistants } from "@/components/context/AssistantsContext"; import { useAssistants } from "@/components/context/AssistantsContext";
import { debounce } from "lodash"; import { debounce } from "lodash";
import { FullLLMProvider } from "../configuration/llm/interfaces"; import { LLMProviderView } from "../configuration/llm/interfaces";
import StarterMessagesList from "./StarterMessageList"; import StarterMessagesList from "./StarterMessageList";
import { Switch, SwitchField } from "@/components/ui/switch"; import { Switch, SwitchField } from "@/components/ui/switch";
@@ -123,7 +123,7 @@ export function AssistantEditor({
documentSets: DocumentSet[]; documentSets: DocumentSet[];
user: User | null; user: User | null;
defaultPublic: boolean; defaultPublic: boolean;
llmProviders: FullLLMProvider[]; llmProviders: LLMProviderView[];
tools: ToolSnapshot[]; tools: ToolSnapshot[];
shouldAddAssistantToUserPreferences?: boolean; shouldAddAssistantToUserPreferences?: boolean;
admin?: boolean; admin?: boolean;

View File

@@ -1,4 +1,4 @@
import { FullLLMProvider } from "../configuration/llm/interfaces"; import { LLMProviderView } from "../configuration/llm/interfaces";
import { Persona, StarterMessage } from "./interfaces"; import { Persona, StarterMessage } from "./interfaces";
interface PersonaUpsertRequest { interface PersonaUpsertRequest {
@@ -319,7 +319,7 @@ export function checkPersonaRequiresImageGeneration(persona: Persona) {
} }
export function providersContainImageGeneratingSupport( export function providersContainImageGeneratingSupport(
providers: FullLLMProvider[] providers: LLMProviderView[]
) { ) {
return providers.some((provider) => provider.provider === "openai"); return providers.some((provider) => provider.provider === "openai");
} }

View File

@@ -1,5 +1,5 @@
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
import { Modal } from "@/components/Modal"; import { Modal } from "@/components/Modal";
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm"; import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm"; import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm";
@@ -19,7 +19,7 @@ function LLMProviderUpdateModal({
}: { }: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined; llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined;
onClose: () => void; onClose: () => void;
existingLlmProvider?: FullLLMProvider; existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean; shouldMarkAsDefault?: boolean;
setPopup?: (popup: PopupSpec) => void; setPopup?: (popup: PopupSpec) => void;
}) { }) {
@@ -61,7 +61,7 @@ function LLMProviderDisplay({
shouldMarkAsDefault, shouldMarkAsDefault,
}: { }: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined; llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined;
existingLlmProvider: FullLLMProvider; existingLlmProvider: LLMProviderView;
shouldMarkAsDefault?: boolean; shouldMarkAsDefault?: boolean;
}) { }) {
const [formIsVisible, setFormIsVisible] = useState(false); const [formIsVisible, setFormIsVisible] = useState(false);
@@ -146,7 +146,7 @@ export function ConfiguredLLMProviderDisplay({
existingLlmProviders, existingLlmProviders,
llmProviderDescriptors, llmProviderDescriptors,
}: { }: {
existingLlmProviders: FullLLMProvider[]; existingLlmProviders: LLMProviderView[];
llmProviderDescriptors: WellKnownLLMProviderDescriptor[]; llmProviderDescriptors: WellKnownLLMProviderDescriptor[];
}) { }) {
existingLlmProviders = existingLlmProviders.sort((a, b) => { existingLlmProviders = existingLlmProviders.sort((a, b) => {

View File

@@ -21,7 +21,7 @@ import {
} from "@/components/admin/connectors/Field"; } from "@/components/admin/connectors/Field";
import { useState } from "react"; import { useState } from "react";
import { useSWRConfig } from "swr"; import { useSWRConfig } from "swr";
import { FullLLMProvider } from "./interfaces"; import { LLMProviderView } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup"; import { PopupSpec } from "@/components/admin/connectors/Popup";
import * as Yup from "yup"; import * as Yup from "yup";
import isEqual from "lodash/isEqual"; import isEqual from "lodash/isEqual";
@@ -43,7 +43,7 @@ export function CustomLLMProviderUpdateForm({
hideSuccess, hideSuccess,
}: { }: {
onClose: () => void; onClose: () => void;
existingLlmProvider?: FullLLMProvider; existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean; shouldMarkAsDefault?: boolean;
setPopup?: (popup: PopupSpec) => void; setPopup?: (popup: PopupSpec) => void;
hideSuccess?: boolean; hideSuccess?: boolean;
@@ -165,7 +165,7 @@ export function CustomLLMProviderUpdateForm({
} }
if (shouldMarkAsDefault) { if (shouldMarkAsDefault) {
const newLlmProvider = (await response.json()) as FullLLMProvider; const newLlmProvider = (await response.json()) as LLMProviderView;
const setDefaultResponse = await fetch( const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`, `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
{ {

View File

@@ -9,7 +9,7 @@ import Text from "@/components/ui/text";
import Title from "@/components/ui/title"; import Title from "@/components/ui/title";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { ThreeDotsLoader } from "@/components/Loading"; import { ThreeDotsLoader } from "@/components/Loading";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm"; import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants"; import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
@@ -25,7 +25,7 @@ function LLMProviderUpdateModal({
}: { }: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null; llmProviderDescriptor: WellKnownLLMProviderDescriptor | null;
onClose: () => void; onClose: () => void;
existingLlmProvider?: FullLLMProvider; existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean; shouldMarkAsDefault?: boolean;
setPopup?: (popup: PopupSpec) => void; setPopup?: (popup: PopupSpec) => void;
}) { }) {
@@ -99,7 +99,7 @@ function DefaultLLMProviderDisplay({
function AddCustomLLMProvider({ function AddCustomLLMProvider({
existingLlmProviders, existingLlmProviders,
}: { }: {
existingLlmProviders: FullLLMProvider[]; existingLlmProviders: LLMProviderView[];
}) { }) {
const [formIsVisible, setFormIsVisible] = useState(false); const [formIsVisible, setFormIsVisible] = useState(false);
@@ -130,7 +130,7 @@ export function LLMConfiguration() {
const { data: llmProviderDescriptors } = useSWR< const { data: llmProviderDescriptors } = useSWR<
WellKnownLLMProviderDescriptor[] WellKnownLLMProviderDescriptor[]
>("/api/admin/llm/built-in/options", errorHandlingFetcher); >("/api/admin/llm/built-in/options", errorHandlingFetcher);
const { data: existingLlmProviders } = useSWR<FullLLMProvider[]>( const { data: existingLlmProviders } = useSWR<LLMProviderView[]>(
LLM_PROVIDERS_ADMIN_URL, LLM_PROVIDERS_ADMIN_URL,
errorHandlingFetcher errorHandlingFetcher
); );

View File

@@ -14,7 +14,7 @@ import {
import { useState } from "react"; import { useState } from "react";
import { useSWRConfig } from "swr"; import { useSWRConfig } from "swr";
import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks"; import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup"; import { PopupSpec } from "@/components/admin/connectors/Popup";
import * as Yup from "yup"; import * as Yup from "yup";
import isEqual from "lodash/isEqual"; import isEqual from "lodash/isEqual";
@@ -31,7 +31,7 @@ export function LLMProviderUpdateForm({
}: { }: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor; llmProviderDescriptor: WellKnownLLMProviderDescriptor;
onClose: () => void; onClose: () => void;
existingLlmProvider?: FullLLMProvider; existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean; shouldMarkAsDefault?: boolean;
hideAdvanced?: boolean; hideAdvanced?: boolean;
setPopup?: (popup: PopupSpec) => void; setPopup?: (popup: PopupSpec) => void;
@@ -73,6 +73,7 @@ export function LLMProviderUpdateForm({
defaultModelsByProvider[llmProviderDescriptor.name] || defaultModelsByProvider[llmProviderDescriptor.name] ||
[], [],
deployment_name: existingLlmProvider?.deployment_name, deployment_name: existingLlmProvider?.deployment_name,
api_key_changed: false,
}; };
// Setup validation schema if required // Setup validation schema if required
@@ -113,6 +114,7 @@ export function LLMProviderUpdateForm({
is_public: Yup.boolean().required(), is_public: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()), groups: Yup.array().of(Yup.number()),
display_model_names: Yup.array().of(Yup.string()), display_model_names: Yup.array().of(Yup.string()),
api_key_changed: Yup.boolean(),
}); });
return ( return (
@@ -122,6 +124,8 @@ export function LLMProviderUpdateForm({
onSubmit={async (values, { setSubmitting }) => { onSubmit={async (values, { setSubmitting }) => {
setSubmitting(true); setSubmitting(true);
values.api_key_changed = values.api_key !== initialValues.api_key;
// test the configuration // test the configuration
if (!isEqual(values, initialValues)) { if (!isEqual(values, initialValues)) {
setIsTesting(true); setIsTesting(true);
@@ -180,7 +184,7 @@ export function LLMProviderUpdateForm({
} }
if (shouldMarkAsDefault) { if (shouldMarkAsDefault) {
const newLlmProvider = (await response.json()) as FullLLMProvider; const newLlmProvider = (await response.json()) as LLMProviderView;
const setDefaultResponse = await fetch( const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`, `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
{ {

View File

@@ -53,14 +53,14 @@ export interface LLMProvider {
is_default_vision_provider: boolean | null; is_default_vision_provider: boolean | null;
} }
export interface FullLLMProvider extends LLMProvider { export interface LLMProviderView extends LLMProvider {
id: number; id: number;
is_default_provider: boolean | null; is_default_provider: boolean | null;
model_names: string[]; model_names: string[];
icon?: React.FC<{ size?: number; className?: string }>; icon?: React.FC<{ size?: number; className?: string }>;
} }
export interface VisionProvider extends FullLLMProvider { export interface VisionProvider extends LLMProviderView {
vision_models: string[]; vision_models: string[];
} }

View File

@@ -1,5 +1,5 @@
import { import {
FullLLMProvider, LLMProviderView,
WellKnownLLMProviderDescriptor, WellKnownLLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces"; } from "@/app/admin/configuration/llm/interfaces";
import { User } from "@/lib/types"; import { User } from "@/lib/types";
@@ -36,7 +36,7 @@ export async function checkLlmProvider(user: User | null) {
const [providerResponse, optionsResponse, defaultCheckResponse] = const [providerResponse, optionsResponse, defaultCheckResponse] =
await Promise.all(tasks); await Promise.all(tasks);
let providers: FullLLMProvider[] = []; let providers: LLMProviderView[] = [];
if (providerResponse?.ok) { if (providerResponse?.ok) {
providers = await providerResponse.json(); providers = await providerResponse.json();
} }

View File

@@ -3,7 +3,7 @@ import { CCPairBasicInfo, DocumentSet, User } from "../types";
import { getCurrentUserSS } from "../userSS"; import { getCurrentUserSS } from "../userSS";
import { fetchSS } from "../utilsSS"; import { fetchSS } from "../utilsSS";
import { import {
FullLLMProvider, LLMProviderView,
getProviderIcon, getProviderIcon,
} from "@/app/admin/configuration/llm/interfaces"; } from "@/app/admin/configuration/llm/interfaces";
import { ToolSnapshot } from "../tools/interfaces"; import { ToolSnapshot } from "../tools/interfaces";
@@ -16,7 +16,7 @@ export async function fetchAssistantEditorInfoSS(
{ {
ccPairs: CCPairBasicInfo[]; ccPairs: CCPairBasicInfo[];
documentSets: DocumentSet[]; documentSets: DocumentSet[];
llmProviders: FullLLMProvider[]; llmProviders: LLMProviderView[];
user: User | null; user: User | null;
existingPersona: Persona | null; existingPersona: Persona | null;
tools: ToolSnapshot[]; tools: ToolSnapshot[];
@@ -83,7 +83,7 @@ export async function fetchAssistantEditorInfoSS(
]; ];
} }
const llmProviders = (await llmProvidersResponse.json()) as FullLLMProvider[]; const llmProviders = (await llmProvidersResponse.json()) as LLMProviderView[];
if (personaId && personaResponse && !personaResponse.ok) { if (personaId && personaResponse && !personaResponse.ok) {
return [null, `Failed to fetch Persona - ${await personaResponse.text()}`]; return [null, `Failed to fetch Persona - ${await personaResponse.text()}`];