rkuo-danswer 85ebadc8eb
sanitize llm keys and handle updates properly (#4270)
* sanitize llm keys and handle updates properly

* fix llm provider testing

* fix test

* mypy

* fix default model editing

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-20 01:13:02 +00:00

122 lines
4.3 KiB
Python

from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import Field
from onyx.llm.llm_provider_options import fetch_models_for_provider
if TYPE_CHECKING:
from onyx.db.models import LLMProvider as LLMProviderModel
class TestLLMRequest(BaseModel):
    """Provider and model settings describing an LLM configuration to test.

    NOTE(review): the purpose is inferred from the class name; confirm
    against the endpoint that consumes this model.
    """

    # provider level
    # `name` is optional here, whereas LLMProvider.name is required.
    name: str | None = None
    provider: str
    api_key: str | None = None
    api_base: str | None = None
    api_version: str | None = None
    custom_config: dict[str, str] | None = None

    # model level
    # Only the default model is mandatory; the other two may be omitted.
    default_model_name: str
    fast_default_model_name: str | None = None
    deployment_name: str | None = None
class LLMProviderDescriptor(BaseModel):
    """A descriptor for an LLM provider that can be safely viewed by
    non-admin users. Used when giving a list of available LLMs.

    The model declares no credential or connection fields (no API key,
    API base, API version or custom config); it only names the provider
    and the models it offers.
    """

    name: str
    provider: str
    model_names: list[str]
    default_model_name: str
    fast_default_model_name: str | None
    is_default_provider: bool | None
    is_default_vision_provider: bool | None
    default_vision_model: str | None
    display_model_names: list[str] | None

    @classmethod
    def from_model(
        cls, llm_provider_model: "LLMProviderModel"
    ) -> "LLMProviderDescriptor":
        """Build a descriptor from a stored LLM provider record.

        Args:
            llm_provider_model: The persisted provider to describe.

        Returns:
            A descriptor populated from the record. ``model_names`` is
            resolved in order of preference: the names stored on the
            record, then the list returned by
            ``fetch_models_for_provider`` for the provider type, and
            finally a one-element list holding the record's default model.
        """
        # Resolve the model list with the same fallback chain that
        # LLMProviderView.from_model uses, so both views of one provider
        # agree. Without the final fallback, a provider with no stored
        # names and no built-in list would be described with an empty
        # model list even though it has a usable default model.
        model_names = (
            llm_provider_model.model_names
            or fetch_models_for_provider(llm_provider_model.provider)
            or [llm_provider_model.default_model_name]
        )

        return cls(
            name=llm_provider_model.name,
            provider=llm_provider_model.provider,
            default_model_name=llm_provider_model.default_model_name,
            fast_default_model_name=llm_provider_model.fast_default_model_name,
            is_default_provider=llm_provider_model.is_default_provider,
            is_default_vision_provider=llm_provider_model.is_default_vision_provider,
            default_vision_model=llm_provider_model.default_vision_model,
            model_names=model_names,
            display_model_names=llm_provider_model.display_model_names,
        )
class LLMProvider(BaseModel):
    """Shared field set for an LLM provider configuration.

    Serves as the base class for LLMProviderUpsertRequest and
    LLMProviderView in this module.
    """

    # Identity of the provider entry and the provider type it uses.
    name: str
    provider: str

    # Connection settings; each is optional.
    api_key: str | None = None
    api_base: str | None = None
    api_version: str | None = None
    custom_config: dict[str, str] | None = None

    # Model selection; only the default model is required.
    default_model_name: str
    fast_default_model_name: str | None = None

    # Visibility: public by default, with an initially empty list of
    # integer group ids (a fresh list per instance via default_factory).
    is_public: bool = True
    groups: list[int] = Field(default_factory=list)

    display_model_names: list[str] | None = None
    deployment_name: str | None = None
    default_vision_model: str | None = None
class LLMProviderUpsertRequest(LLMProvider):
    """LLMProvider fields plus the extras supplied on an upsert.

    NOTE(review): "upsert" semantics are inferred from the class name;
    confirm against the handler that accepts this model.
    """

    # should only be used for a "custom" provider
    # for default providers, the built-in model names are used
    model_names: list[str] | None = None

    # Defaults to False.
    # NOTE(review): presumably tells the handler whether `api_key` carries a
    # new value that should replace the stored one - confirm with the caller.
    api_key_changed: bool = False
class LLMProviderView(LLMProvider):
    """Stripped down representation of LLMProvider for display / limited access info only"""

    id: int
    is_default_provider: bool | None = None
    is_default_vision_provider: bool | None = None
    model_names: list[str]

    @classmethod
    def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
        """Build a view from a stored LLM provider record.

        Args:
            llm_provider_model: The persisted provider to represent.

        Returns:
            A view whose fields are copied from the record, with
            ``groups`` reduced to the ids of the record's groups and
            ``model_names`` resolved through a fallback chain (see below).
        """
        # Prefer the names stored on the record, then the list returned by
        # fetch_models_for_provider for this provider type, and as a last
        # resort a single-element list holding the default model.
        available_models = (
            llm_provider_model.model_names
            or fetch_models_for_provider(llm_provider_model.provider)
            or [llm_provider_model.default_model_name]
        )

        # The view exposes group ids only, not the group objects.
        group_ids = [group.id for group in llm_provider_model.groups]

        return cls(
            # Fields declared on this class.
            id=llm_provider_model.id,
            is_default_provider=llm_provider_model.is_default_provider,
            is_default_vision_provider=llm_provider_model.is_default_vision_provider,
            model_names=available_models,
            # Fields inherited from LLMProvider.
            name=llm_provider_model.name,
            provider=llm_provider_model.provider,
            # NOTE(review): the key is copied verbatim here; any masking
            # before it reaches a client presumably happens in the caller
            # - confirm.
            api_key=llm_provider_model.api_key,
            api_base=llm_provider_model.api_base,
            api_version=llm_provider_model.api_version,
            custom_config=llm_provider_model.custom_config,
            default_model_name=llm_provider_model.default_model_name,
            fast_default_model_name=llm_provider_model.fast_default_model_name,
            is_public=llm_provider_model.is_public,
            groups=group_ids,
            display_model_names=llm_provider_model.display_model_names,
            deployment_name=llm_provider_model.deployment_name,
            default_vision_model=llm_provider_model.default_vision_model,
        )
class VisionProviderResponse(LLMProviderView):
    """Response model for vision providers endpoint, including vision-specific fields."""

    # Required list of model names, supplied by the caller in addition to
    # every field inherited from LLMProviderView.
    vision_models: list[str]