mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-24 20:06:32 +02:00
sanitize llm keys and handle updates properly (#4270)
* sanitize llm keys and handle updates properly * fix llm provider testing * fix test * mypy * fix default model editing --------- Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app> Co-authored-by: Richard Kuo <rkuo@rkuo.com>
This commit is contained in:
@@ -271,6 +271,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
|||||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||||
model_names=ANTHROPIC_MODEL_NAMES,
|
model_names=ANTHROPIC_MODEL_NAMES,
|
||||||
display_model_names=["claude-3-5-sonnet-20241022"],
|
display_model_names=["claude-3-5-sonnet-20241022"],
|
||||||
|
api_key_changed=True,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||||
@@ -283,7 +284,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if OPENAI_DEFAULT_API_KEY:
|
if OPENAI_DEFAULT_API_KEY:
|
||||||
open_provider = LLMProviderUpsertRequest(
|
openai_provider = LLMProviderUpsertRequest(
|
||||||
name="OpenAI",
|
name="OpenAI",
|
||||||
provider=OPENAI_PROVIDER_NAME,
|
provider=OPENAI_PROVIDER_NAME,
|
||||||
api_key=OPENAI_DEFAULT_API_KEY,
|
api_key=OPENAI_DEFAULT_API_KEY,
|
||||||
@@ -291,9 +292,10 @@ def configure_default_api_keys(db_session: Session) -> None:
|
|||||||
fast_default_model_name="gpt-4o-mini",
|
fast_default_model_name="gpt-4o-mini",
|
||||||
model_names=OPEN_AI_MODEL_NAMES,
|
model_names=OPEN_AI_MODEL_NAMES,
|
||||||
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
|
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
|
||||||
|
api_key_changed=True,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
full_provider = upsert_llm_provider(open_provider, db_session)
|
full_provider = upsert_llm_provider(openai_provider, db_session)
|
||||||
update_default_provider(full_provider.id, db_session)
|
update_default_provider(full_provider.id, db_session)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to configure OpenAI provider: {e}")
|
logger.error(f"Failed to configure OpenAI provider: {e}")
|
||||||
|
@@ -16,8 +16,8 @@ from onyx.db.models import User__UserGroup
|
|||||||
from onyx.llm.utils import model_supports_image_input
|
from onyx.llm.utils import model_supports_image_input
|
||||||
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||||
from onyx.server.manage.llm.models import FullLLMProvider
|
|
||||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||||
|
from onyx.server.manage.llm.models import LLMProviderView
|
||||||
from shared_configs.enums import EmbeddingProvider
|
from shared_configs.enums import EmbeddingProvider
|
||||||
|
|
||||||
|
|
||||||
@@ -67,7 +67,7 @@ def upsert_cloud_embedding_provider(
|
|||||||
def upsert_llm_provider(
|
def upsert_llm_provider(
|
||||||
llm_provider: LLMProviderUpsertRequest,
|
llm_provider: LLMProviderUpsertRequest,
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
) -> FullLLMProvider:
|
) -> LLMProviderView:
|
||||||
existing_llm_provider = db_session.scalar(
|
existing_llm_provider = db_session.scalar(
|
||||||
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
||||||
)
|
)
|
||||||
@@ -98,7 +98,7 @@ def upsert_llm_provider(
|
|||||||
group_ids=llm_provider.groups,
|
group_ids=llm_provider.groups,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
)
|
)
|
||||||
full_llm_provider = FullLLMProvider.from_model(existing_llm_provider)
|
full_llm_provider = LLMProviderView.from_model(existing_llm_provider)
|
||||||
|
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
@@ -132,6 +132,16 @@ def fetch_existing_llm_providers(
|
|||||||
return list(db_session.scalars(stmt).all())
|
return list(db_session.scalars(stmt).all())
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_existing_llm_provider(
|
||||||
|
provider_name: str, db_session: Session
|
||||||
|
) -> LLMProviderModel | None:
|
||||||
|
provider_model = db_session.scalar(
|
||||||
|
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
|
||||||
|
)
|
||||||
|
|
||||||
|
return provider_model
|
||||||
|
|
||||||
|
|
||||||
def fetch_existing_llm_providers_for_user(
|
def fetch_existing_llm_providers_for_user(
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
user: User | None = None,
|
user: User | None = None,
|
||||||
@@ -177,7 +187,7 @@ def fetch_embedding_provider(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
|
||||||
provider_model = db_session.scalar(
|
provider_model = db_session.scalar(
|
||||||
select(LLMProviderModel).where(
|
select(LLMProviderModel).where(
|
||||||
LLMProviderModel.is_default_provider == True # noqa: E712
|
LLMProviderModel.is_default_provider == True # noqa: E712
|
||||||
@@ -185,10 +195,10 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
|||||||
)
|
)
|
||||||
if not provider_model:
|
if not provider_model:
|
||||||
return None
|
return None
|
||||||
return FullLLMProvider.from_model(provider_model)
|
return LLMProviderView.from_model(provider_model)
|
||||||
|
|
||||||
|
|
||||||
def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None:
|
def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None:
|
||||||
provider_model = db_session.scalar(
|
provider_model = db_session.scalar(
|
||||||
select(LLMProviderModel).where(
|
select(LLMProviderModel).where(
|
||||||
LLMProviderModel.is_default_vision_provider == True # noqa: E712
|
LLMProviderModel.is_default_vision_provider == True # noqa: E712
|
||||||
@@ -196,16 +206,18 @@ def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None
|
|||||||
)
|
)
|
||||||
if not provider_model:
|
if not provider_model:
|
||||||
return None
|
return None
|
||||||
return FullLLMProvider.from_model(provider_model)
|
return LLMProviderView.from_model(provider_model)
|
||||||
|
|
||||||
|
|
||||||
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
|
def fetch_llm_provider_view(
|
||||||
|
db_session: Session, provider_name: str
|
||||||
|
) -> LLMProviderView | None:
|
||||||
provider_model = db_session.scalar(
|
provider_model = db_session.scalar(
|
||||||
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
|
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
|
||||||
)
|
)
|
||||||
if not provider_model:
|
if not provider_model:
|
||||||
return None
|
return None
|
||||||
return FullLLMProvider.from_model(provider_model)
|
return LLMProviderView.from_model(provider_model)
|
||||||
|
|
||||||
|
|
||||||
def remove_embedding_provider(
|
def remove_embedding_provider(
|
||||||
|
@@ -9,14 +9,14 @@ from onyx.db.engine import get_session_with_current_tenant
|
|||||||
from onyx.db.llm import fetch_default_provider
|
from onyx.db.llm import fetch_default_provider
|
||||||
from onyx.db.llm import fetch_default_vision_provider
|
from onyx.db.llm import fetch_default_vision_provider
|
||||||
from onyx.db.llm import fetch_existing_llm_providers
|
from onyx.db.llm import fetch_existing_llm_providers
|
||||||
from onyx.db.llm import fetch_provider
|
from onyx.db.llm import fetch_llm_provider_view
|
||||||
from onyx.db.models import Persona
|
from onyx.db.models import Persona
|
||||||
from onyx.llm.chat_llm import DefaultMultiLLM
|
from onyx.llm.chat_llm import DefaultMultiLLM
|
||||||
from onyx.llm.exceptions import GenAIDisabledException
|
from onyx.llm.exceptions import GenAIDisabledException
|
||||||
from onyx.llm.interfaces import LLM
|
from onyx.llm.interfaces import LLM
|
||||||
from onyx.llm.override_models import LLMOverride
|
from onyx.llm.override_models import LLMOverride
|
||||||
from onyx.llm.utils import model_supports_image_input
|
from onyx.llm.utils import model_supports_image_input
|
||||||
from onyx.server.manage.llm.models import FullLLMProvider
|
from onyx.server.manage.llm.models import LLMProviderView
|
||||||
from onyx.utils.headers import build_llm_extra_headers
|
from onyx.utils.headers import build_llm_extra_headers
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
from onyx.utils.long_term_log import LongTermLogger
|
from onyx.utils.long_term_log import LongTermLogger
|
||||||
@@ -62,7 +62,7 @@ def get_llms_for_persona(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with get_session_context_manager() as db_session:
|
with get_session_context_manager() as db_session:
|
||||||
llm_provider = fetch_provider(db_session, provider_name)
|
llm_provider = fetch_llm_provider_view(db_session, provider_name)
|
||||||
|
|
||||||
if not llm_provider:
|
if not llm_provider:
|
||||||
raise ValueError("No LLM provider found")
|
raise ValueError("No LLM provider found")
|
||||||
@@ -106,7 +106,7 @@ def get_default_llm_with_vision(
|
|||||||
if DISABLE_GENERATIVE_AI:
|
if DISABLE_GENERATIVE_AI:
|
||||||
raise GenAIDisabledException()
|
raise GenAIDisabledException()
|
||||||
|
|
||||||
def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM:
|
def create_vision_llm(provider: LLMProviderView, model: str) -> LLM:
|
||||||
"""Helper to create an LLM if the provider supports image input."""
|
"""Helper to create an LLM if the provider supports image input."""
|
||||||
return get_llm(
|
return get_llm(
|
||||||
provider=provider.provider,
|
provider=provider.provider,
|
||||||
@@ -148,7 +148,7 @@ def get_default_llm_with_vision(
|
|||||||
provider.default_vision_model, provider.provider
|
provider.default_vision_model, provider.provider
|
||||||
):
|
):
|
||||||
return create_vision_llm(
|
return create_vision_llm(
|
||||||
FullLLMProvider.from_model(provider), provider.default_vision_model
|
LLMProviderView.from_model(provider), provider.default_vision_model
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
@@ -9,9 +9,9 @@ from sqlalchemy.orm import Session
|
|||||||
from onyx.auth.users import current_admin_user
|
from onyx.auth.users import current_admin_user
|
||||||
from onyx.auth.users import current_chat_accessible_user
|
from onyx.auth.users import current_chat_accessible_user
|
||||||
from onyx.db.engine import get_session
|
from onyx.db.engine import get_session
|
||||||
|
from onyx.db.llm import fetch_existing_llm_provider
|
||||||
from onyx.db.llm import fetch_existing_llm_providers
|
from onyx.db.llm import fetch_existing_llm_providers
|
||||||
from onyx.db.llm import fetch_existing_llm_providers_for_user
|
from onyx.db.llm import fetch_existing_llm_providers_for_user
|
||||||
from onyx.db.llm import fetch_provider
|
|
||||||
from onyx.db.llm import remove_llm_provider
|
from onyx.db.llm import remove_llm_provider
|
||||||
from onyx.db.llm import update_default_provider
|
from onyx.db.llm import update_default_provider
|
||||||
from onyx.db.llm import update_default_vision_provider
|
from onyx.db.llm import update_default_vision_provider
|
||||||
@@ -24,9 +24,9 @@ from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
|||||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||||
from onyx.llm.utils import model_supports_image_input
|
from onyx.llm.utils import model_supports_image_input
|
||||||
from onyx.llm.utils import test_llm
|
from onyx.llm.utils import test_llm
|
||||||
from onyx.server.manage.llm.models import FullLLMProvider
|
|
||||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||||
|
from onyx.server.manage.llm.models import LLMProviderView
|
||||||
from onyx.server.manage.llm.models import TestLLMRequest
|
from onyx.server.manage.llm.models import TestLLMRequest
|
||||||
from onyx.server.manage.llm.models import VisionProviderResponse
|
from onyx.server.manage.llm.models import VisionProviderResponse
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
@@ -49,11 +49,27 @@ def fetch_llm_options(
|
|||||||
def test_llm_configuration(
|
def test_llm_configuration(
|
||||||
test_llm_request: TestLLMRequest,
|
test_llm_request: TestLLMRequest,
|
||||||
_: User | None = Depends(current_admin_user),
|
_: User | None = Depends(current_admin_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Test regular llm and fast llm settings"""
|
||||||
|
|
||||||
|
# the api key is sanitized if we are testing a provider already in the system
|
||||||
|
|
||||||
|
test_api_key = test_llm_request.api_key
|
||||||
|
if test_llm_request.name:
|
||||||
|
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
|
||||||
|
# as it turns out the name is not editable in the UI and other code also keys off name,
|
||||||
|
# so we won't rock the boat just yet.
|
||||||
|
existing_provider = fetch_existing_llm_provider(
|
||||||
|
test_llm_request.name, db_session
|
||||||
|
)
|
||||||
|
if existing_provider:
|
||||||
|
test_api_key = existing_provider.api_key
|
||||||
|
|
||||||
llm = get_llm(
|
llm = get_llm(
|
||||||
provider=test_llm_request.provider,
|
provider=test_llm_request.provider,
|
||||||
model=test_llm_request.default_model_name,
|
model=test_llm_request.default_model_name,
|
||||||
api_key=test_llm_request.api_key,
|
api_key=test_api_key,
|
||||||
api_base=test_llm_request.api_base,
|
api_base=test_llm_request.api_base,
|
||||||
api_version=test_llm_request.api_version,
|
api_version=test_llm_request.api_version,
|
||||||
custom_config=test_llm_request.custom_config,
|
custom_config=test_llm_request.custom_config,
|
||||||
@@ -69,7 +85,7 @@ def test_llm_configuration(
|
|||||||
fast_llm = get_llm(
|
fast_llm = get_llm(
|
||||||
provider=test_llm_request.provider,
|
provider=test_llm_request.provider,
|
||||||
model=test_llm_request.fast_default_model_name,
|
model=test_llm_request.fast_default_model_name,
|
||||||
api_key=test_llm_request.api_key,
|
api_key=test_api_key,
|
||||||
api_base=test_llm_request.api_base,
|
api_base=test_llm_request.api_base,
|
||||||
api_version=test_llm_request.api_version,
|
api_version=test_llm_request.api_version,
|
||||||
custom_config=test_llm_request.custom_config,
|
custom_config=test_llm_request.custom_config,
|
||||||
@@ -119,11 +135,17 @@ def test_default_provider(
|
|||||||
def list_llm_providers(
|
def list_llm_providers(
|
||||||
_: User | None = Depends(current_admin_user),
|
_: User | None = Depends(current_admin_user),
|
||||||
db_session: Session = Depends(get_session),
|
db_session: Session = Depends(get_session),
|
||||||
) -> list[FullLLMProvider]:
|
) -> list[LLMProviderView]:
|
||||||
return [
|
llm_provider_list: list[LLMProviderView] = []
|
||||||
FullLLMProvider.from_model(llm_provider_model)
|
for llm_provider_model in fetch_existing_llm_providers(db_session):
|
||||||
for llm_provider_model in fetch_existing_llm_providers(db_session)
|
full_llm_provider = LLMProviderView.from_model(llm_provider_model)
|
||||||
]
|
if full_llm_provider.api_key:
|
||||||
|
full_llm_provider.api_key = (
|
||||||
|
full_llm_provider.api_key[:4] + "****" + full_llm_provider.api_key[-4:]
|
||||||
|
)
|
||||||
|
llm_provider_list.append(full_llm_provider)
|
||||||
|
|
||||||
|
return llm_provider_list
|
||||||
|
|
||||||
|
|
||||||
@admin_router.put("/provider")
|
@admin_router.put("/provider")
|
||||||
@@ -135,11 +157,11 @@ def put_llm_provider(
|
|||||||
),
|
),
|
||||||
_: User | None = Depends(current_admin_user),
|
_: User | None = Depends(current_admin_user),
|
||||||
db_session: Session = Depends(get_session),
|
db_session: Session = Depends(get_session),
|
||||||
) -> FullLLMProvider:
|
) -> LLMProviderView:
|
||||||
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
|
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
|
||||||
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
|
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
|
||||||
# the result
|
# the result
|
||||||
existing_provider = fetch_provider(db_session, llm_provider.name)
|
existing_provider = fetch_existing_llm_provider(llm_provider.name, db_session)
|
||||||
if existing_provider and is_creation:
|
if existing_provider and is_creation:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
@@ -161,6 +183,11 @@ def put_llm_provider(
|
|||||||
llm_provider.fast_default_model_name
|
llm_provider.fast_default_model_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# the llm api key is sanitized when returned to clients, so the only time we
|
||||||
|
# should get a real key is when it is explicitly changed
|
||||||
|
if existing_provider and not llm_provider.api_key_changed:
|
||||||
|
llm_provider.api_key = existing_provider.api_key
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return upsert_llm_provider(
|
return upsert_llm_provider(
|
||||||
llm_provider=llm_provider,
|
llm_provider=llm_provider,
|
||||||
@@ -234,7 +261,7 @@ def get_vision_capable_providers(
|
|||||||
|
|
||||||
# Only include providers with at least one vision-capable model
|
# Only include providers with at least one vision-capable model
|
||||||
if vision_models:
|
if vision_models:
|
||||||
provider_dict = FullLLMProvider.from_model(provider).model_dump()
|
provider_dict = LLMProviderView.from_model(provider).model_dump()
|
||||||
provider_dict["vision_models"] = vision_models
|
provider_dict["vision_models"] = vision_models
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Vision provider: {provider.provider} with models: {vision_models}"
|
f"Vision provider: {provider.provider} with models: {vision_models}"
|
||||||
|
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
class TestLLMRequest(BaseModel):
|
class TestLLMRequest(BaseModel):
|
||||||
# provider level
|
# provider level
|
||||||
|
name: str | None = None
|
||||||
provider: str
|
provider: str
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
api_base: str | None = None
|
api_base: str | None = None
|
||||||
@@ -76,16 +77,19 @@ class LLMProviderUpsertRequest(LLMProvider):
|
|||||||
# should only be used for a "custom" provider
|
# should only be used for a "custom" provider
|
||||||
# for default providers, the built-in model names are used
|
# for default providers, the built-in model names are used
|
||||||
model_names: list[str] | None = None
|
model_names: list[str] | None = None
|
||||||
|
api_key_changed: bool = False
|
||||||
|
|
||||||
|
|
||||||
class FullLLMProvider(LLMProvider):
|
class LLMProviderView(LLMProvider):
|
||||||
|
"""Stripped down representation of LLMProvider for display / limited access info only"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
is_default_provider: bool | None = None
|
is_default_provider: bool | None = None
|
||||||
is_default_vision_provider: bool | None = None
|
is_default_vision_provider: bool | None = None
|
||||||
model_names: list[str]
|
model_names: list[str]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider":
|
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
|
||||||
return cls(
|
return cls(
|
||||||
id=llm_provider_model.id,
|
id=llm_provider_model.id,
|
||||||
name=llm_provider_model.name,
|
name=llm_provider_model.name,
|
||||||
@@ -111,7 +115,7 @@ class FullLLMProvider(LLMProvider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class VisionProviderResponse(FullLLMProvider):
|
class VisionProviderResponse(LLMProviderView):
|
||||||
"""Response model for vision providers endpoint, including vision-specific fields."""
|
"""Response model for vision providers endpoint, including vision-specific fields."""
|
||||||
|
|
||||||
vision_models: list[str]
|
vision_models: list[str]
|
||||||
|
@@ -307,6 +307,7 @@ def setup_postgres(db_session: Session) -> None:
|
|||||||
groups=[],
|
groups=[],
|
||||||
display_model_names=OPEN_AI_MODEL_NAMES,
|
display_model_names=OPEN_AI_MODEL_NAMES,
|
||||||
model_names=OPEN_AI_MODEL_NAMES,
|
model_names=OPEN_AI_MODEL_NAMES,
|
||||||
|
api_key_changed=True,
|
||||||
)
|
)
|
||||||
new_llm_provider = upsert_llm_provider(
|
new_llm_provider = upsert_llm_provider(
|
||||||
llm_provider=model_req, db_session=db_session
|
llm_provider=model_req, db_session=db_session
|
||||||
|
@@ -3,8 +3,8 @@ from uuid import uuid4
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from onyx.server.manage.llm.models import FullLLMProvider
|
|
||||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||||
|
from onyx.server.manage.llm.models import LLMProviderView
|
||||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||||
@@ -39,6 +39,7 @@ class LLMProviderManager:
|
|||||||
groups=groups or [],
|
groups=groups or [],
|
||||||
display_model_names=None,
|
display_model_names=None,
|
||||||
model_names=None,
|
model_names=None,
|
||||||
|
api_key_changed=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
llm_response = requests.put(
|
llm_response = requests.put(
|
||||||
@@ -90,7 +91,7 @@ class LLMProviderManager:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_all(
|
def get_all(
|
||||||
user_performing_action: DATestUser | None = None,
|
user_performing_action: DATestUser | None = None,
|
||||||
) -> list[FullLLMProvider]:
|
) -> list[LLMProviderView]:
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{API_SERVER_URL}/admin/llm/provider",
|
f"{API_SERVER_URL}/admin/llm/provider",
|
||||||
headers=user_performing_action.headers
|
headers=user_performing_action.headers
|
||||||
@@ -98,7 +99,7 @@ class LLMProviderManager:
|
|||||||
else GENERAL_HEADERS,
|
else GENERAL_HEADERS,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return [FullLLMProvider(**ug) for ug in response.json()]
|
return [LLMProviderView(**ug) for ug in response.json()]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify(
|
def verify(
|
||||||
@@ -111,18 +112,19 @@ class LLMProviderManager:
|
|||||||
if llm_provider.id == fetched_llm_provider.id:
|
if llm_provider.id == fetched_llm_provider.id:
|
||||||
if verify_deleted:
|
if verify_deleted:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"User group {llm_provider.id} found but should be deleted"
|
f"LLM Provider {llm_provider.id} found but should be deleted"
|
||||||
)
|
)
|
||||||
fetched_llm_groups = set(fetched_llm_provider.groups)
|
fetched_llm_groups = set(fetched_llm_provider.groups)
|
||||||
llm_provider_groups = set(llm_provider.groups)
|
llm_provider_groups = set(llm_provider.groups)
|
||||||
|
|
||||||
|
# NOTE: returned api keys are sanitized and should not match
|
||||||
if (
|
if (
|
||||||
fetched_llm_groups == llm_provider_groups
|
fetched_llm_groups == llm_provider_groups
|
||||||
and llm_provider.provider == fetched_llm_provider.provider
|
and llm_provider.provider == fetched_llm_provider.provider
|
||||||
and llm_provider.api_key == fetched_llm_provider.api_key
|
|
||||||
and llm_provider.default_model_name
|
and llm_provider.default_model_name
|
||||||
== fetched_llm_provider.default_model_name
|
== fetched_llm_provider.default_model_name
|
||||||
and llm_provider.is_public == fetched_llm_provider.is_public
|
and llm_provider.is_public == fetched_llm_provider.is_public
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
if not verify_deleted:
|
if not verify_deleted:
|
||||||
raise ValueError(f"User group {llm_provider.id} not found")
|
raise ValueError(f"LLM Provider {llm_provider.id} not found")
|
||||||
|
@@ -34,6 +34,7 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
|
|||||||
json={
|
json={
|
||||||
"name": str(uuid.uuid4()),
|
"name": str(uuid.uuid4()),
|
||||||
"provider": "openai",
|
"provider": "openai",
|
||||||
|
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||||
"default_model_name": _DEFAULT_MODELS[0],
|
"default_model_name": _DEFAULT_MODELS[0],
|
||||||
"model_names": _DEFAULT_MODELS,
|
"model_names": _DEFAULT_MODELS,
|
||||||
"is_public": True,
|
"is_public": True,
|
||||||
@@ -49,6 +50,9 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
|
|||||||
assert provider_data["model_names"] == _DEFAULT_MODELS
|
assert provider_data["model_names"] == _DEFAULT_MODELS
|
||||||
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
|
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
|
||||||
assert provider_data["display_model_names"] is None
|
assert provider_data["display_model_names"] is None
|
||||||
|
assert (
|
||||||
|
provider_data["api_key"] == "sk-0****0000"
|
||||||
|
) # test that returned key is sanitized
|
||||||
|
|
||||||
|
|
||||||
def test_update_llm_provider_model_names(reset: None) -> None:
|
def test_update_llm_provider_model_names(reset: None) -> None:
|
||||||
@@ -64,10 +68,12 @@ def test_update_llm_provider_model_names(reset: None) -> None:
|
|||||||
json={
|
json={
|
||||||
"name": name,
|
"name": name,
|
||||||
"provider": "openai",
|
"provider": "openai",
|
||||||
|
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||||
"default_model_name": _DEFAULT_MODELS[0],
|
"default_model_name": _DEFAULT_MODELS[0],
|
||||||
"model_names": [_DEFAULT_MODELS[0]],
|
"model_names": [_DEFAULT_MODELS[0]],
|
||||||
"is_public": True,
|
"is_public": True,
|
||||||
"groups": [],
|
"groups": [],
|
||||||
|
"api_key_changed": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -81,6 +87,7 @@ def test_update_llm_provider_model_names(reset: None) -> None:
|
|||||||
"id": created_provider["id"],
|
"id": created_provider["id"],
|
||||||
"name": name,
|
"name": name,
|
||||||
"provider": created_provider["provider"],
|
"provider": created_provider["provider"],
|
||||||
|
"api_key": "sk-000000000000000000000000000000000000000000000001",
|
||||||
"default_model_name": _DEFAULT_MODELS[0],
|
"default_model_name": _DEFAULT_MODELS[0],
|
||||||
"model_names": _DEFAULT_MODELS,
|
"model_names": _DEFAULT_MODELS,
|
||||||
"is_public": True,
|
"is_public": True,
|
||||||
@@ -93,6 +100,30 @@ def test_update_llm_provider_model_names(reset: None) -> None:
|
|||||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||||
assert provider_data is not None
|
assert provider_data is not None
|
||||||
assert provider_data["model_names"] == _DEFAULT_MODELS
|
assert provider_data["model_names"] == _DEFAULT_MODELS
|
||||||
|
assert (
|
||||||
|
provider_data["api_key"] == "sk-0****0000"
|
||||||
|
) # test that key was NOT updated due to api_key_changed not being set
|
||||||
|
|
||||||
|
# Update with api_key_changed properly set
|
||||||
|
response = requests.put(
|
||||||
|
f"{API_SERVER_URL}/admin/llm/provider",
|
||||||
|
headers=admin_user.headers,
|
||||||
|
json={
|
||||||
|
"id": created_provider["id"],
|
||||||
|
"name": name,
|
||||||
|
"provider": created_provider["provider"],
|
||||||
|
"api_key": "sk-000000000000000000000000000000000000000000000001",
|
||||||
|
"default_model_name": _DEFAULT_MODELS[0],
|
||||||
|
"model_names": _DEFAULT_MODELS,
|
||||||
|
"is_public": True,
|
||||||
|
"groups": [],
|
||||||
|
"api_key_changed": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||||
|
assert provider_data is not None
|
||||||
|
assert provider_data["api_key"] == "sk-0****0001" # test that key was updated
|
||||||
|
|
||||||
|
|
||||||
def test_delete_llm_provider(reset: None) -> None:
|
def test_delete_llm_provider(reset: None) -> None:
|
||||||
@@ -107,6 +138,7 @@ def test_delete_llm_provider(reset: None) -> None:
|
|||||||
json={
|
json={
|
||||||
"name": "test-provider-delete",
|
"name": "test-provider-delete",
|
||||||
"provider": "openai",
|
"provider": "openai",
|
||||||
|
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||||
"default_model_name": _DEFAULT_MODELS[0],
|
"default_model_name": _DEFAULT_MODELS[0],
|
||||||
"model_names": _DEFAULT_MODELS,
|
"model_names": _DEFAULT_MODELS,
|
||||||
"is_public": True,
|
"is_public": True,
|
||||||
|
@@ -61,7 +61,7 @@ import {
|
|||||||
import { buildImgUrl } from "@/app/chat/files/images/utils";
|
import { buildImgUrl } from "@/app/chat/files/images/utils";
|
||||||
import { useAssistants } from "@/components/context/AssistantsContext";
|
import { useAssistants } from "@/components/context/AssistantsContext";
|
||||||
import { debounce } from "lodash";
|
import { debounce } from "lodash";
|
||||||
import { FullLLMProvider } from "../configuration/llm/interfaces";
|
import { LLMProviderView } from "../configuration/llm/interfaces";
|
||||||
import StarterMessagesList from "./StarterMessageList";
|
import StarterMessagesList from "./StarterMessageList";
|
||||||
|
|
||||||
import { Switch, SwitchField } from "@/components/ui/switch";
|
import { Switch, SwitchField } from "@/components/ui/switch";
|
||||||
@@ -123,7 +123,7 @@ export function AssistantEditor({
|
|||||||
documentSets: DocumentSet[];
|
documentSets: DocumentSet[];
|
||||||
user: User | null;
|
user: User | null;
|
||||||
defaultPublic: boolean;
|
defaultPublic: boolean;
|
||||||
llmProviders: FullLLMProvider[];
|
llmProviders: LLMProviderView[];
|
||||||
tools: ToolSnapshot[];
|
tools: ToolSnapshot[];
|
||||||
shouldAddAssistantToUserPreferences?: boolean;
|
shouldAddAssistantToUserPreferences?: boolean;
|
||||||
admin?: boolean;
|
admin?: boolean;
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
import { FullLLMProvider } from "../configuration/llm/interfaces";
|
import { LLMProviderView } from "../configuration/llm/interfaces";
|
||||||
import { Persona, StarterMessage } from "./interfaces";
|
import { Persona, StarterMessage } from "./interfaces";
|
||||||
|
|
||||||
interface PersonaUpsertRequest {
|
interface PersonaUpsertRequest {
|
||||||
@@ -319,7 +319,7 @@ export function checkPersonaRequiresImageGeneration(persona: Persona) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function providersContainImageGeneratingSupport(
|
export function providersContainImageGeneratingSupport(
|
||||||
providers: FullLLMProvider[]
|
providers: LLMProviderView[]
|
||||||
) {
|
) {
|
||||||
return providers.some((provider) => provider.provider === "openai");
|
return providers.some((provider) => provider.provider === "openai");
|
||||||
}
|
}
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
||||||
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
|
import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
|
||||||
import { Modal } from "@/components/Modal";
|
import { Modal } from "@/components/Modal";
|
||||||
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
|
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
|
||||||
import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm";
|
import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm";
|
||||||
@@ -19,7 +19,7 @@ function LLMProviderUpdateModal({
|
|||||||
}: {
|
}: {
|
||||||
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined;
|
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined;
|
||||||
onClose: () => void;
|
onClose: () => void;
|
||||||
existingLlmProvider?: FullLLMProvider;
|
existingLlmProvider?: LLMProviderView;
|
||||||
shouldMarkAsDefault?: boolean;
|
shouldMarkAsDefault?: boolean;
|
||||||
setPopup?: (popup: PopupSpec) => void;
|
setPopup?: (popup: PopupSpec) => void;
|
||||||
}) {
|
}) {
|
||||||
@@ -61,7 +61,7 @@ function LLMProviderDisplay({
|
|||||||
shouldMarkAsDefault,
|
shouldMarkAsDefault,
|
||||||
}: {
|
}: {
|
||||||
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined;
|
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined;
|
||||||
existingLlmProvider: FullLLMProvider;
|
existingLlmProvider: LLMProviderView;
|
||||||
shouldMarkAsDefault?: boolean;
|
shouldMarkAsDefault?: boolean;
|
||||||
}) {
|
}) {
|
||||||
const [formIsVisible, setFormIsVisible] = useState(false);
|
const [formIsVisible, setFormIsVisible] = useState(false);
|
||||||
@@ -146,7 +146,7 @@ export function ConfiguredLLMProviderDisplay({
|
|||||||
existingLlmProviders,
|
existingLlmProviders,
|
||||||
llmProviderDescriptors,
|
llmProviderDescriptors,
|
||||||
}: {
|
}: {
|
||||||
existingLlmProviders: FullLLMProvider[];
|
existingLlmProviders: LLMProviderView[];
|
||||||
llmProviderDescriptors: WellKnownLLMProviderDescriptor[];
|
llmProviderDescriptors: WellKnownLLMProviderDescriptor[];
|
||||||
}) {
|
}) {
|
||||||
existingLlmProviders = existingLlmProviders.sort((a, b) => {
|
existingLlmProviders = existingLlmProviders.sort((a, b) => {
|
||||||
|
@@ -21,7 +21,7 @@ import {
|
|||||||
} from "@/components/admin/connectors/Field";
|
} from "@/components/admin/connectors/Field";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
import { useSWRConfig } from "swr";
|
import { useSWRConfig } from "swr";
|
||||||
import { FullLLMProvider } from "./interfaces";
|
import { LLMProviderView } from "./interfaces";
|
||||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||||
import * as Yup from "yup";
|
import * as Yup from "yup";
|
||||||
import isEqual from "lodash/isEqual";
|
import isEqual from "lodash/isEqual";
|
||||||
@@ -43,7 +43,7 @@ export function CustomLLMProviderUpdateForm({
|
|||||||
hideSuccess,
|
hideSuccess,
|
||||||
}: {
|
}: {
|
||||||
onClose: () => void;
|
onClose: () => void;
|
||||||
existingLlmProvider?: FullLLMProvider;
|
existingLlmProvider?: LLMProviderView;
|
||||||
shouldMarkAsDefault?: boolean;
|
shouldMarkAsDefault?: boolean;
|
||||||
setPopup?: (popup: PopupSpec) => void;
|
setPopup?: (popup: PopupSpec) => void;
|
||||||
hideSuccess?: boolean;
|
hideSuccess?: boolean;
|
||||||
@@ -165,7 +165,7 @@ export function CustomLLMProviderUpdateForm({
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (shouldMarkAsDefault) {
|
if (shouldMarkAsDefault) {
|
||||||
const newLlmProvider = (await response.json()) as FullLLMProvider;
|
const newLlmProvider = (await response.json()) as LLMProviderView;
|
||||||
const setDefaultResponse = await fetch(
|
const setDefaultResponse = await fetch(
|
||||||
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
|
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
|
||||||
{
|
{
|
||||||
|
@@ -9,7 +9,7 @@ import Text from "@/components/ui/text";
|
|||||||
import Title from "@/components/ui/title";
|
import Title from "@/components/ui/title";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { ThreeDotsLoader } from "@/components/Loading";
|
import { ThreeDotsLoader } from "@/components/Loading";
|
||||||
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
|
import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
|
||||||
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
||||||
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
|
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
|
||||||
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
|
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
|
||||||
@@ -25,7 +25,7 @@ function LLMProviderUpdateModal({
|
|||||||
}: {
|
}: {
|
||||||
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null;
|
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null;
|
||||||
onClose: () => void;
|
onClose: () => void;
|
||||||
existingLlmProvider?: FullLLMProvider;
|
existingLlmProvider?: LLMProviderView;
|
||||||
shouldMarkAsDefault?: boolean;
|
shouldMarkAsDefault?: boolean;
|
||||||
setPopup?: (popup: PopupSpec) => void;
|
setPopup?: (popup: PopupSpec) => void;
|
||||||
}) {
|
}) {
|
||||||
@@ -99,7 +99,7 @@ function DefaultLLMProviderDisplay({
|
|||||||
function AddCustomLLMProvider({
|
function AddCustomLLMProvider({
|
||||||
existingLlmProviders,
|
existingLlmProviders,
|
||||||
}: {
|
}: {
|
||||||
existingLlmProviders: FullLLMProvider[];
|
existingLlmProviders: LLMProviderView[];
|
||||||
}) {
|
}) {
|
||||||
const [formIsVisible, setFormIsVisible] = useState(false);
|
const [formIsVisible, setFormIsVisible] = useState(false);
|
||||||
|
|
||||||
@@ -130,7 +130,7 @@ export function LLMConfiguration() {
|
|||||||
const { data: llmProviderDescriptors } = useSWR<
|
const { data: llmProviderDescriptors } = useSWR<
|
||||||
WellKnownLLMProviderDescriptor[]
|
WellKnownLLMProviderDescriptor[]
|
||||||
>("/api/admin/llm/built-in/options", errorHandlingFetcher);
|
>("/api/admin/llm/built-in/options", errorHandlingFetcher);
|
||||||
const { data: existingLlmProviders } = useSWR<FullLLMProvider[]>(
|
const { data: existingLlmProviders } = useSWR<LLMProviderView[]>(
|
||||||
LLM_PROVIDERS_ADMIN_URL,
|
LLM_PROVIDERS_ADMIN_URL,
|
||||||
errorHandlingFetcher
|
errorHandlingFetcher
|
||||||
);
|
);
|
||||||
|
@@ -14,7 +14,7 @@ import {
|
|||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
import { useSWRConfig } from "swr";
|
import { useSWRConfig } from "swr";
|
||||||
import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks";
|
import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks";
|
||||||
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
|
import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
|
||||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||||
import * as Yup from "yup";
|
import * as Yup from "yup";
|
||||||
import isEqual from "lodash/isEqual";
|
import isEqual from "lodash/isEqual";
|
||||||
@@ -31,7 +31,7 @@ export function LLMProviderUpdateForm({
|
|||||||
}: {
|
}: {
|
||||||
llmProviderDescriptor: WellKnownLLMProviderDescriptor;
|
llmProviderDescriptor: WellKnownLLMProviderDescriptor;
|
||||||
onClose: () => void;
|
onClose: () => void;
|
||||||
existingLlmProvider?: FullLLMProvider;
|
existingLlmProvider?: LLMProviderView;
|
||||||
shouldMarkAsDefault?: boolean;
|
shouldMarkAsDefault?: boolean;
|
||||||
hideAdvanced?: boolean;
|
hideAdvanced?: boolean;
|
||||||
setPopup?: (popup: PopupSpec) => void;
|
setPopup?: (popup: PopupSpec) => void;
|
||||||
@@ -73,6 +73,7 @@ export function LLMProviderUpdateForm({
|
|||||||
defaultModelsByProvider[llmProviderDescriptor.name] ||
|
defaultModelsByProvider[llmProviderDescriptor.name] ||
|
||||||
[],
|
[],
|
||||||
deployment_name: existingLlmProvider?.deployment_name,
|
deployment_name: existingLlmProvider?.deployment_name,
|
||||||
|
api_key_changed: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Setup validation schema if required
|
// Setup validation schema if required
|
||||||
@@ -113,6 +114,7 @@ export function LLMProviderUpdateForm({
|
|||||||
is_public: Yup.boolean().required(),
|
is_public: Yup.boolean().required(),
|
||||||
groups: Yup.array().of(Yup.number()),
|
groups: Yup.array().of(Yup.number()),
|
||||||
display_model_names: Yup.array().of(Yup.string()),
|
display_model_names: Yup.array().of(Yup.string()),
|
||||||
|
api_key_changed: Yup.boolean(),
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -122,6 +124,8 @@ export function LLMProviderUpdateForm({
|
|||||||
onSubmit={async (values, { setSubmitting }) => {
|
onSubmit={async (values, { setSubmitting }) => {
|
||||||
setSubmitting(true);
|
setSubmitting(true);
|
||||||
|
|
||||||
|
values.api_key_changed = values.api_key !== initialValues.api_key;
|
||||||
|
|
||||||
// test the configuration
|
// test the configuration
|
||||||
if (!isEqual(values, initialValues)) {
|
if (!isEqual(values, initialValues)) {
|
||||||
setIsTesting(true);
|
setIsTesting(true);
|
||||||
@@ -180,7 +184,7 @@ export function LLMProviderUpdateForm({
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (shouldMarkAsDefault) {
|
if (shouldMarkAsDefault) {
|
||||||
const newLlmProvider = (await response.json()) as FullLLMProvider;
|
const newLlmProvider = (await response.json()) as LLMProviderView;
|
||||||
const setDefaultResponse = await fetch(
|
const setDefaultResponse = await fetch(
|
||||||
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
|
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
|
||||||
{
|
{
|
||||||
|
@@ -53,14 +53,14 @@ export interface LLMProvider {
|
|||||||
is_default_vision_provider: boolean | null;
|
is_default_vision_provider: boolean | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface FullLLMProvider extends LLMProvider {
|
export interface LLMProviderView extends LLMProvider {
|
||||||
id: number;
|
id: number;
|
||||||
is_default_provider: boolean | null;
|
is_default_provider: boolean | null;
|
||||||
model_names: string[];
|
model_names: string[];
|
||||||
icon?: React.FC<{ size?: number; className?: string }>;
|
icon?: React.FC<{ size?: number; className?: string }>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface VisionProvider extends FullLLMProvider {
|
export interface VisionProvider extends LLMProviderView {
|
||||||
vision_models: string[];
|
vision_models: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
FullLLMProvider,
|
LLMProviderView,
|
||||||
WellKnownLLMProviderDescriptor,
|
WellKnownLLMProviderDescriptor,
|
||||||
} from "@/app/admin/configuration/llm/interfaces";
|
} from "@/app/admin/configuration/llm/interfaces";
|
||||||
import { User } from "@/lib/types";
|
import { User } from "@/lib/types";
|
||||||
@@ -36,7 +36,7 @@ export async function checkLlmProvider(user: User | null) {
|
|||||||
const [providerResponse, optionsResponse, defaultCheckResponse] =
|
const [providerResponse, optionsResponse, defaultCheckResponse] =
|
||||||
await Promise.all(tasks);
|
await Promise.all(tasks);
|
||||||
|
|
||||||
let providers: FullLLMProvider[] = [];
|
let providers: LLMProviderView[] = [];
|
||||||
if (providerResponse?.ok) {
|
if (providerResponse?.ok) {
|
||||||
providers = await providerResponse.json();
|
providers = await providerResponse.json();
|
||||||
}
|
}
|
||||||
|
@@ -3,7 +3,7 @@ import { CCPairBasicInfo, DocumentSet, User } from "../types";
|
|||||||
import { getCurrentUserSS } from "../userSS";
|
import { getCurrentUserSS } from "../userSS";
|
||||||
import { fetchSS } from "../utilsSS";
|
import { fetchSS } from "../utilsSS";
|
||||||
import {
|
import {
|
||||||
FullLLMProvider,
|
LLMProviderView,
|
||||||
getProviderIcon,
|
getProviderIcon,
|
||||||
} from "@/app/admin/configuration/llm/interfaces";
|
} from "@/app/admin/configuration/llm/interfaces";
|
||||||
import { ToolSnapshot } from "../tools/interfaces";
|
import { ToolSnapshot } from "../tools/interfaces";
|
||||||
@@ -16,7 +16,7 @@ export async function fetchAssistantEditorInfoSS(
|
|||||||
{
|
{
|
||||||
ccPairs: CCPairBasicInfo[];
|
ccPairs: CCPairBasicInfo[];
|
||||||
documentSets: DocumentSet[];
|
documentSets: DocumentSet[];
|
||||||
llmProviders: FullLLMProvider[];
|
llmProviders: LLMProviderView[];
|
||||||
user: User | null;
|
user: User | null;
|
||||||
existingPersona: Persona | null;
|
existingPersona: Persona | null;
|
||||||
tools: ToolSnapshot[];
|
tools: ToolSnapshot[];
|
||||||
@@ -83,7 +83,7 @@ export async function fetchAssistantEditorInfoSS(
|
|||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
const llmProviders = (await llmProvidersResponse.json()) as FullLLMProvider[];
|
const llmProviders = (await llmProvidersResponse.json()) as LLMProviderView[];
|
||||||
|
|
||||||
if (personaId && personaResponse && !personaResponse.ok) {
|
if (personaId && personaResponse && !personaResponse.ok) {
|
||||||
return [null, `Failed to fetch Persona - ${await personaResponse.text()}`];
|
return [null, `Failed to fetch Persona - ${await personaResponse.text()}`];
|
||||||
|
Reference in New Issue
Block a user