mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-29 03:01:59 +01:00
* sanitize llm keys and handle updates properly * fix llm provider testing * fix test * mypy * fix default model editing --------- Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app> Co-authored-by: Richard Kuo <rkuo@rkuo.com>
122 lines
4.3 KiB
Python
122 lines
4.3 KiB
Python
from typing import TYPE_CHECKING
|
|
|
|
from pydantic import BaseModel
|
|
from pydantic import Field
|
|
|
|
from onyx.llm.llm_provider_options import fetch_models_for_provider
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from onyx.db.models import LLMProvider as LLMProviderModel
|
|
|
|
|
|
class TestLLMRequest(BaseModel):
    # Payload describing a provider + model configuration to be tested.
    # NOTE(review): presumably consumed by a "test provider" endpoint — confirm.
    # Field order is significant (it drives schema / serialization order).

    # --- provider-level settings ---
    name: str | None = Field(default=None)
    provider: str
    api_key: str | None = Field(default=None)
    api_base: str | None = Field(default=None)
    api_version: str | None = Field(default=None)
    custom_config: dict[str, str] | None = Field(default=None)

    # --- model-level settings ---
    default_model_name: str
    fast_default_model_name: str | None = Field(default=None)
    deployment_name: str | None = Field(default=None)
|
|
|
|
|
|
class LLMProviderDescriptor(BaseModel):
    """A descriptor for an LLM provider that can be safely viewed by
    non-admin users. Used when giving a list of available LLMs."""

    # Unlike LLMProviderView, this model has no api_key / api_base /
    # api_version / custom_config fields, so none of those are exposed.
    name: str
    provider: str
    model_names: list[str]
    default_model_name: str
    fast_default_model_name: str | None
    is_default_provider: bool | None
    is_default_vision_provider: bool | None
    default_vision_model: str | None
    display_model_names: list[str] | None

    @classmethod
    def from_model(
        cls, llm_provider_model: "LLMProviderModel"
    ) -> "LLMProviderDescriptor":
        """Build a descriptor from an ``onyx.db.models.LLMProvider`` instance.

        ``model_names`` is resolved in this order:

        1. the model names stored on the provider, if non-empty;
        2. otherwise, the list returned by ``fetch_models_for_provider``;
        3. otherwise, ``[default_model_name]``.

        Step 3 mirrors ``LLMProviderView.from_model``. It ensures the provider's
        default model is always listed, instead of returning an empty list for
        providers without stored or built-in model lists.
        """
        model_names = (
            llm_provider_model.model_names
            or fetch_models_for_provider(llm_provider_model.provider)
            or [llm_provider_model.default_model_name]
        )

        return cls(
            name=llm_provider_model.name,
            provider=llm_provider_model.provider,
            default_model_name=llm_provider_model.default_model_name,
            fast_default_model_name=llm_provider_model.fast_default_model_name,
            is_default_provider=llm_provider_model.is_default_provider,
            is_default_vision_provider=llm_provider_model.is_default_vision_provider,
            default_vision_model=llm_provider_model.default_vision_model,
            model_names=model_names,
            display_model_names=llm_provider_model.display_model_names,
        )
|
|
|
|
|
|
class LLMProvider(BaseModel):
    # Base provider configuration shared by the upsert request and view models.
    # Field order is significant (it drives schema / serialization order).

    # --- identity and connection settings ---
    name: str
    provider: str
    api_key: str | None = Field(default=None)
    api_base: str | None = Field(default=None)
    api_version: str | None = Field(default=None)
    custom_config: dict[str, str] | None = Field(default=None)

    # --- model selection ---
    default_model_name: str
    fast_default_model_name: str | None = Field(default=None)

    # --- visibility: public by default; `groups` is a list of int ids
    # (presumably user-group ids — confirm) ---
    is_public: bool = Field(default=True)
    groups: list[int] = Field(default_factory=list)

    # --- optional extras ---
    display_model_names: list[str] | None = Field(default=None)
    deployment_name: str | None = Field(default=None)
    default_vision_model: str | None = Field(default=None)
|
|
|
|
|
|
class LLMProviderUpsertRequest(LLMProvider):
    # Explicit model list. Intended only for "custom" providers; for the
    # default providers, the built-in model names are used instead.
    model_names: list[str] | None = Field(default=None)

    # Defaults to False.
    # NOTE(review): presumably signals that `api_key` carries a newly entered
    # key (rather than an unchanged / masked one) — confirm in the upsert handler.
    api_key_changed: bool = Field(default=False)
|
|
|
|
|
|
class LLMProviderView(LLMProvider):
    """Stripped down representation of LLMProvider for display / limited access info only"""

    id: int
    is_default_provider: bool | None = Field(default=None)
    is_default_vision_provider: bool | None = Field(default=None)
    model_names: list[str]

    @classmethod
    def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
        """Build a view from an ``onyx.db.models.LLMProvider`` instance.

        NOTE(review): ``api_key`` is copied verbatim from the stored provider,
        and nothing in this method masks it. Confirm that callers sanitize it
        before it reaches clients.
        """
        src = llm_provider_model

        # Fallback chain: stored list -> provider's built-in list -> just the
        # default model, so the result always holds at least one entry.
        resolved_model_names = (
            src.model_names
            or fetch_models_for_provider(src.provider)
            or [src.default_model_name]
        )
        group_ids = [grp.id for grp in src.groups]

        return cls(
            id=src.id,
            name=src.name,
            provider=src.provider,
            api_key=src.api_key,
            api_base=src.api_base,
            api_version=src.api_version,
            custom_config=src.custom_config,
            default_model_name=src.default_model_name,
            fast_default_model_name=src.fast_default_model_name,
            is_default_provider=src.is_default_provider,
            is_default_vision_provider=src.is_default_vision_provider,
            default_vision_model=src.default_vision_model,
            display_model_names=src.display_model_names,
            model_names=resolved_model_names,
            is_public=src.is_public,
            groups=group_ids,
            deployment_name=src.deployment_name,
        )
|
|
|
|
|
|
class VisionProviderResponse(LLMProviderView):
    """Response model for vision providers endpoint, including vision-specific fields."""

    # The only field added on top of LLMProviderView: vision model names for
    # this provider. Required; there is no default.
    vision_models: list[str]
|