sanitize llm keys and handle updates properly

This commit is contained in:
Richard Kuo (Danswer) 2025-03-12 15:19:59 -07:00
parent a9e5ae2f11
commit 660021bf29
9 changed files with 89 additions and 25 deletions

View File

@ -215,6 +215,7 @@ def configure_default_api_keys(db_session: Session) -> None:
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
api_key_changed=True,
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
@ -227,7 +228,7 @@ def configure_default_api_keys(db_session: Session) -> None:
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
openai_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
@ -235,9 +236,10 @@ def configure_default_api_keys(db_session: Session) -> None:
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
api_key_changed=True,
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
full_provider = upsert_llm_provider(openai_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")

View File

@ -15,8 +15,8 @@ from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from shared_configs.enums import EmbeddingProvider
@ -66,7 +66,7 @@ def upsert_cloud_embedding_provider(
def upsert_llm_provider(
llm_provider: LLMProviderUpsertRequest,
db_session: Session,
) -> FullLLMProvider:
) -> LLMProviderView:
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
@ -97,7 +97,7 @@ def upsert_llm_provider(
group_ids=llm_provider.groups,
db_session=db_session,
)
full_llm_provider = FullLLMProvider.from_model(existing_llm_provider)
full_llm_provider = LLMProviderView.from_model(existing_llm_provider)
db_session.commit()
@ -131,6 +131,16 @@ def fetch_existing_llm_providers(
return list(db_session.scalars(stmt).all())
def fetch_existing_llm_provider(
provider_name: str, db_session: Session
) -> LLMProviderModel | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
)
return provider_model
def fetch_existing_llm_providers_for_user(
db_session: Session,
user: User | None = None,
@ -176,7 +186,7 @@ def fetch_embedding_provider(
)
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_provider == True # noqa: E712
@ -184,16 +194,18 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
return LLMProviderView.from_model(provider_model)
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
def fetch_llm_provider_view(
db_session: Session, provider_name: str
) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
return LLMProviderView.from_model(provider_model)
def remove_embedding_provider(

View File

@ -7,7 +7,7 @@ from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine import get_session_context_manager
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_provider
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.models import Persona
from onyx.llm.chat_llm import DefaultMultiLLM
from onyx.llm.exceptions import GenAIDisabledException
@ -59,7 +59,7 @@ def get_llms_for_persona(
)
with get_session_context_manager() as db_session:
llm_provider = fetch_provider(db_session, provider_name)
llm_provider = fetch_llm_provider_view(db_session, provider_name)
if not llm_provider:
raise ValueError("No LLM provider found")

View File

@ -9,9 +9,9 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user
from onyx.db.llm import fetch_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
@ -22,9 +22,9 @@ from onyx.llm.llm_provider_options import fetch_available_well_known_llms
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@ -116,11 +116,17 @@ def test_default_provider(
def list_llm_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[FullLLMProvider]:
return [
FullLLMProvider.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
]
) -> list[LLMProviderView]:
llm_provider_list: list[LLMProviderView] = []
for llm_provider_model in fetch_existing_llm_providers(db_session):
full_llm_provider = LLMProviderView.from_model(llm_provider_model)
if full_llm_provider.api_key:
full_llm_provider.api_key = (
full_llm_provider.api_key[:4] + "****" + full_llm_provider.api_key[-4:]
)
llm_provider_list.append(full_llm_provider)
return llm_provider_list
@admin_router.put("/provider")
@ -132,11 +138,11 @@ def put_llm_provider(
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProvider:
) -> LLMProviderView:
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = fetch_provider(db_session, llm_provider.name)
existing_provider = fetch_existing_llm_provider(llm_provider.name, db_session)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
@ -158,6 +164,11 @@ def put_llm_provider(
llm_provider.fast_default_model_name
)
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider.api_key_changed:
llm_provider.api_key = existing_provider.api_key
try:
return upsert_llm_provider(
llm_provider=llm_provider,

View File

@ -74,15 +74,18 @@ class LLMProviderUpsertRequest(LLMProvider):
# should only be used for a "custom" provider
# for default providers, the built-in model names are used
model_names: list[str] | None = None
api_key_changed: bool = False
class FullLLMProvider(LLMProvider):
class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
id: int
is_default_provider: bool | None = None
model_names: list[str]
@classmethod
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider":
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,

View File

@ -307,6 +307,7 @@ def setup_postgres(db_session: Session) -> None:
groups=[],
display_model_names=OPEN_AI_MODEL_NAMES,
model_names=OPEN_AI_MODEL_NAMES,
api_key_changed=True,
)
new_llm_provider = upsert_llm_provider(
llm_provider=model_req, db_session=db_session

View File

@ -3,8 +3,8 @@ from uuid import uuid4
import requests
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
@ -39,6 +39,7 @@ class LLMProviderManager:
groups=groups or [],
display_model_names=None,
model_names=None,
api_key_changed=True,
)
llm_response = requests.put(
@ -90,7 +91,7 @@ class LLMProviderManager:
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
) -> list[FullLLMProvider]:
) -> list[LLMProviderView]:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=user_performing_action.headers
@ -98,7 +99,7 @@ class LLMProviderManager:
else GENERAL_HEADERS,
)
response.raise_for_status()
return [FullLLMProvider(**ug) for ug in response.json()]
return [LLMProviderView(**ug) for ug in response.json()]
@staticmethod
def verify(

View File

@ -34,6 +34,7 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
json={
"name": str(uuid.uuid4()),
"provider": "openai",
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
@ -49,6 +50,9 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
assert provider_data["model_names"] == _DEFAULT_MODELS
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
assert provider_data["display_model_names"] is None
assert (
provider_data["api_key"] == "sk-0****0000"
) # test that returned key is sanitized
def test_update_llm_provider_model_names(reset: None) -> None:
@ -64,10 +68,12 @@ def test_update_llm_provider_model_names(reset: None) -> None:
json={
"name": name,
"provider": "openai",
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": [_DEFAULT_MODELS[0]],
"is_public": True,
"groups": [],
"api_key_changed": True,
},
)
assert response.status_code == 200
@ -81,6 +87,7 @@ def test_update_llm_provider_model_names(reset: None) -> None:
"id": created_provider["id"],
"name": name,
"provider": created_provider["provider"],
"api_key": "sk-000000000000000000000000000000000000000000000001",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
@ -93,6 +100,28 @@ def test_update_llm_provider_model_names(reset: None) -> None:
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is not None
assert provider_data["model_names"] == _DEFAULT_MODELS
assert (
provider_data["api_key"] == "sk-0****0000"
) # test that key was NOT updated due to api_key_changed not being set
# Update with api_key_changed properly set
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"id": created_provider["id"],
"name": name,
"provider": created_provider["provider"],
"api_key": "sk-000000000000000000000000000000000000000000000001",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
"groups": [],
"api_key_changed": True,
},
)
assert response.status_code == 200

# re-fetch the provider so we check the post-update state, not the stale copy
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is not None
assert (
provider_data["api_key"] == "sk-0****0001"
) # test that key WAS updated since api_key_changed was set
def test_delete_llm_provider(reset: None) -> None:
@ -107,6 +136,7 @@ def test_delete_llm_provider(reset: None) -> None:
json={
"name": "test-provider-delete",
"provider": "openai",
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,

View File

@ -73,6 +73,7 @@ export function LLMProviderUpdateForm({
defaultModelsByProvider[llmProviderDescriptor.name] ||
[],
deployment_name: existingLlmProvider?.deployment_name,
api_key_changed: false,
};
// Setup validation schema if required
@ -113,6 +114,7 @@ export function LLMProviderUpdateForm({
is_public: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
display_model_names: Yup.array().of(Yup.string()),
api_key_changed: Yup.boolean(),
});
return (
@ -122,6 +124,8 @@ export function LLMProviderUpdateForm({
onSubmit={async (values, { setSubmitting }) => {
setSubmitting(true);
values.api_key_changed = values.api_key !== initialValues.api_key;
// test the configuration
if (!isEqual(values, initialValues)) {
setIsTesting(true);