mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-18 05:41:58 +01:00
sanitize llm keys and handle updates properly
This commit is contained in:
parent
a9e5ae2f11
commit
660021bf29
@ -215,6 +215,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_names=ANTHROPIC_MODEL_NAMES,
|
||||
display_model_names=["claude-3-5-sonnet-20241022"],
|
||||
api_key_changed=True,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||
@ -227,7 +228,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
)
|
||||
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
open_provider = LLMProviderUpsertRequest(
|
||||
openai_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
@ -235,9 +236,10 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_names=OPEN_AI_MODEL_NAMES,
|
||||
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
|
||||
api_key_changed=True,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(open_provider, db_session)
|
||||
full_provider = upsert_llm_provider(openai_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure OpenAI provider: {e}")
|
||||
|
@ -15,8 +15,8 @@ from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import FullLLMProvider
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
|
||||
@ -66,7 +66,7 @@ def upsert_cloud_embedding_provider(
|
||||
def upsert_llm_provider(
|
||||
llm_provider: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
) -> FullLLMProvider:
|
||||
) -> LLMProviderView:
|
||||
existing_llm_provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
||||
)
|
||||
@ -97,7 +97,7 @@ def upsert_llm_provider(
|
||||
group_ids=llm_provider.groups,
|
||||
db_session=db_session,
|
||||
)
|
||||
full_llm_provider = FullLLMProvider.from_model(existing_llm_provider)
|
||||
full_llm_provider = LLMProviderView.from_model(existing_llm_provider)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@ -131,6 +131,16 @@ def fetch_existing_llm_providers(
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def fetch_existing_llm_provider(
|
||||
provider_name: str, db_session: Session
|
||||
) -> LLMProviderModel | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
|
||||
)
|
||||
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_existing_llm_providers_for_user(
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
@ -176,7 +186,7 @@ def fetch_embedding_provider(
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
||||
def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.is_default_provider == True # noqa: E712
|
||||
@ -184,16 +194,18 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
||||
)
|
||||
if not provider_model:
|
||||
return None
|
||||
return FullLLMProvider.from_model(provider_model)
|
||||
return LLMProviderView.from_model(provider_model)
|
||||
|
||||
|
||||
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
|
||||
def fetch_llm_provider_view(
|
||||
db_session: Session, provider_name: str
|
||||
) -> LLMProviderView | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
|
||||
)
|
||||
if not provider_model:
|
||||
return None
|
||||
return FullLLMProvider.from_model(provider_model)
|
||||
return LLMProviderView.from_model(provider_model)
|
||||
|
||||
|
||||
def remove_embedding_provider(
|
||||
|
@ -7,7 +7,7 @@ from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.llm import fetch_default_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_provider
|
||||
from onyx.db.llm import fetch_llm_provider_view
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.chat_llm import DefaultMultiLLM
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
@ -59,7 +59,7 @@ def get_llms_for_persona(
|
||||
)
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
llm_provider = fetch_provider(db_session, provider_name)
|
||||
llm_provider = fetch_llm_provider_view(db_session, provider_name)
|
||||
|
||||
if not llm_provider:
|
||||
raise ValueError("No LLM provider found")
|
||||
|
@ -9,9 +9,9 @@ from sqlalchemy.orm import Session
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_llm_providers_for_user
|
||||
from onyx.db.llm import fetch_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
@ -22,9 +22,9 @@ from onyx.llm.llm_provider_options import fetch_available_well_known_llms
|
||||
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.llm.utils import test_llm
|
||||
from onyx.server.manage.llm.models import FullLLMProvider
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import TestLLMRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@ -116,11 +116,17 @@ def test_default_provider(
|
||||
def list_llm_providers(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[FullLLMProvider]:
|
||||
return [
|
||||
FullLLMProvider.from_model(llm_provider_model)
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session)
|
||||
]
|
||||
) -> list[LLMProviderView]:
|
||||
llm_provider_list: list[LLMProviderView] = []
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session):
|
||||
full_llm_provider = LLMProviderView.from_model(llm_provider_model)
|
||||
if full_llm_provider.api_key:
|
||||
full_llm_provider.api_key = (
|
||||
full_llm_provider.api_key[:4] + "****" + full_llm_provider.api_key[-4:]
|
||||
)
|
||||
llm_provider_list.append(full_llm_provider)
|
||||
|
||||
return llm_provider_list
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
@ -132,11 +138,11 @@ def put_llm_provider(
|
||||
),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FullLLMProvider:
|
||||
) -> LLMProviderView:
|
||||
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
|
||||
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
|
||||
# the result
|
||||
existing_provider = fetch_provider(db_session, llm_provider.name)
|
||||
existing_provider = fetch_existing_llm_provider(llm_provider.name, db_session)
|
||||
if existing_provider and is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@ -158,6 +164,11 @@ def put_llm_provider(
|
||||
llm_provider.fast_default_model_name
|
||||
)
|
||||
|
||||
# the llm api key is sanitized when returned to clients, so the only time we
|
||||
# should get a real key is when it is explicitly changed
|
||||
if existing_provider and not llm_provider.api_key_changed:
|
||||
llm_provider.api_key = existing_provider.api_key
|
||||
|
||||
try:
|
||||
return upsert_llm_provider(
|
||||
llm_provider=llm_provider,
|
||||
|
@ -74,15 +74,18 @@ class LLMProviderUpsertRequest(LLMProvider):
|
||||
# should only be used for a "custom" provider
|
||||
# for default providers, the built-in model names are used
|
||||
model_names: list[str] | None = None
|
||||
api_key_changed: bool = False
|
||||
|
||||
|
||||
class FullLLMProvider(LLMProvider):
|
||||
class LLMProviderView(LLMProvider):
|
||||
"""Stripped down representation of LLMProvider for display / limited access info only"""
|
||||
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
model_names: list[str]
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider":
|
||||
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
|
@ -307,6 +307,7 @@ def setup_postgres(db_session: Session) -> None:
|
||||
groups=[],
|
||||
display_model_names=OPEN_AI_MODEL_NAMES,
|
||||
model_names=OPEN_AI_MODEL_NAMES,
|
||||
api_key_changed=True,
|
||||
)
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider=model_req, db_session=db_session
|
||||
|
@ -3,8 +3,8 @@ from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.server.manage.llm.models import FullLLMProvider
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
@ -39,6 +39,7 @@ class LLMProviderManager:
|
||||
groups=groups or [],
|
||||
display_model_names=None,
|
||||
model_names=None,
|
||||
api_key_changed=True,
|
||||
)
|
||||
|
||||
llm_response = requests.put(
|
||||
@ -90,7 +91,7 @@ class LLMProviderManager:
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[FullLLMProvider]:
|
||||
) -> list[LLMProviderView]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=user_performing_action.headers
|
||||
@ -98,7 +99,7 @@ class LLMProviderManager:
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [FullLLMProvider(**ug) for ug in response.json()]
|
||||
return [LLMProviderView(**ug) for ug in response.json()]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
|
@ -34,6 +34,7 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
|
||||
json={
|
||||
"name": str(uuid.uuid4()),
|
||||
"provider": "openai",
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
@ -49,6 +50,9 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
|
||||
assert provider_data["model_names"] == _DEFAULT_MODELS
|
||||
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
|
||||
assert provider_data["display_model_names"] is None
|
||||
assert (
|
||||
provider_data["api_key"] == "sk-0****0000"
|
||||
) # test that returned key is sanitized
|
||||
|
||||
|
||||
def test_update_llm_provider_model_names(reset: None) -> None:
|
||||
@ -64,10 +68,12 @@ def test_update_llm_provider_model_names(reset: None) -> None:
|
||||
json={
|
||||
"name": name,
|
||||
"provider": "openai",
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": [_DEFAULT_MODELS[0]],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
"api_key_changed": True,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@ -81,6 +87,7 @@ def test_update_llm_provider_model_names(reset: None) -> None:
|
||||
"id": created_provider["id"],
|
||||
"name": name,
|
||||
"provider": created_provider["provider"],
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000001",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
@ -93,6 +100,28 @@ def test_update_llm_provider_model_names(reset: None) -> None:
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
assert provider_data is not None
|
||||
assert provider_data["model_names"] == _DEFAULT_MODELS
|
||||
assert (
|
||||
provider_data["api_key"] == "sk-0****0000"
|
||||
) # test that key was NOT updated due to api_key_changed not being set
|
||||
|
||||
# Update with api_key_changed properly set
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"id": created_provider["id"],
|
||||
"name": name,
|
||||
"provider": created_provider["provider"],
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000001",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
"api_key_changed": True,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert provider_data["api_key"] == "sk-0****0000" # test that key was updated
|
||||
|
||||
|
||||
def test_delete_llm_provider(reset: None) -> None:
|
||||
@ -107,6 +136,7 @@ def test_delete_llm_provider(reset: None) -> None:
|
||||
json={
|
||||
"name": "test-provider-delete",
|
||||
"provider": "openai",
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
|
@ -73,6 +73,7 @@ export function LLMProviderUpdateForm({
|
||||
defaultModelsByProvider[llmProviderDescriptor.name] ||
|
||||
[],
|
||||
deployment_name: existingLlmProvider?.deployment_name,
|
||||
api_key_changed: false,
|
||||
};
|
||||
|
||||
// Setup validation schema if required
|
||||
@ -113,6 +114,7 @@ export function LLMProviderUpdateForm({
|
||||
is_public: Yup.boolean().required(),
|
||||
groups: Yup.array().of(Yup.number()),
|
||||
display_model_names: Yup.array().of(Yup.string()),
|
||||
api_key_changed: Yup.boolean(),
|
||||
});
|
||||
|
||||
return (
|
||||
@ -122,6 +124,8 @@ export function LLMProviderUpdateForm({
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
setSubmitting(true);
|
||||
|
||||
values.api_key_changed = values.api_key !== initialValues.api_key;
|
||||
|
||||
// test the configuration
|
||||
if (!isEqual(values, initialValues)) {
|
||||
setIsTesting(true);
|
||||
|
Loading…
x
Reference in New Issue
Block a user