Mirror of https://github.com/danswer-ai/danswer.git
Synced 2025-03-26 01:31:51 +01:00
sanitize llm keys and handle updates properly (#4270)
* sanitize llm keys and handle updates properly
* fix llm provider testing
* fix test
* mypy
* fix default model editing

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
This commit is contained in:
parent 5dda53eec3
commit 85ebadc8eb
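In short, the commit touches three paths around provider API keys: admin list/read responses now return a masked key, updates only overwrite the stored key when the request flags it via api_key_changed, and the test endpoint falls back to the stored key when an existing provider is being tested. A minimal sketch of the masking scheme the list endpoint applies (the helper name is illustrative; the real code inlines this in list_llm_providers):

def sanitize_api_key(api_key: str) -> str:
    # keep the first and last four characters, mask everything in between
    return api_key[:4] + "****" + api_key[-4:]


print(sanitize_api_key("sk-00000000000000000000"))  # -> sk-0****0000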
@@ -271,6 +271,7 @@ def configure_default_api_keys(db_session: Session) -> None:
             fast_default_model_name="claude-3-5-sonnet-20241022",
             model_names=ANTHROPIC_MODEL_NAMES,
             display_model_names=["claude-3-5-sonnet-20241022"],
+            api_key_changed=True,
         )
         try:
             full_provider = upsert_llm_provider(anthropic_provider, db_session)
@@ -283,7 +284,7 @@ def configure_default_api_keys(db_session: Session) -> None:
         )

     if OPENAI_DEFAULT_API_KEY:
-        open_provider = LLMProviderUpsertRequest(
+        openai_provider = LLMProviderUpsertRequest(
             name="OpenAI",
             provider=OPENAI_PROVIDER_NAME,
             api_key=OPENAI_DEFAULT_API_KEY,
@@ -291,9 +292,10 @@ def configure_default_api_keys(db_session: Session) -> None:
             fast_default_model_name="gpt-4o-mini",
             model_names=OPEN_AI_MODEL_NAMES,
             display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
+            api_key_changed=True,
         )
         try:
-            full_provider = upsert_llm_provider(open_provider, db_session)
+            full_provider = upsert_llm_provider(openai_provider, db_session)
             update_default_provider(full_provider.id, db_session)
         except Exception as e:
             logger.error(f"Failed to configure OpenAI provider: {e}")

@@ -16,8 +16,8 @@ from onyx.db.models import User__UserGroup
 from onyx.llm.utils import model_supports_image_input
 from onyx.server.manage.embedding.models import CloudEmbeddingProvider
 from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
-from onyx.server.manage.llm.models import FullLLMProvider
 from onyx.server.manage.llm.models import LLMProviderUpsertRequest
+from onyx.server.manage.llm.models import LLMProviderView
 from shared_configs.enums import EmbeddingProvider


@@ -67,7 +67,7 @@ def upsert_cloud_embedding_provider(
 def upsert_llm_provider(
     llm_provider: LLMProviderUpsertRequest,
     db_session: Session,
-) -> FullLLMProvider:
+) -> LLMProviderView:
     existing_llm_provider = db_session.scalar(
         select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
     )
@@ -98,7 +98,7 @@ def upsert_llm_provider(
         group_ids=llm_provider.groups,
         db_session=db_session,
     )
-    full_llm_provider = FullLLMProvider.from_model(existing_llm_provider)
+    full_llm_provider = LLMProviderView.from_model(existing_llm_provider)

     db_session.commit()

@@ -132,6 +132,16 @@ def fetch_existing_llm_providers(
     return list(db_session.scalars(stmt).all())


+def fetch_existing_llm_provider(
+    provider_name: str, db_session: Session
+) -> LLMProviderModel | None:
+    provider_model = db_session.scalar(
+        select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
+    )
+
+    return provider_model
+
+
 def fetch_existing_llm_providers_for_user(
     db_session: Session,
     user: User | None = None,
@@ -177,7 +187,7 @@ def fetch_embedding_provider(
     )


-def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
+def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
     provider_model = db_session.scalar(
         select(LLMProviderModel).where(
             LLMProviderModel.is_default_provider == True  # noqa: E712
@@ -185,10 +195,10 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
     )
     if not provider_model:
         return None
-    return FullLLMProvider.from_model(provider_model)
+    return LLMProviderView.from_model(provider_model)


-def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None:
+def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None:
     provider_model = db_session.scalar(
         select(LLMProviderModel).where(
             LLMProviderModel.is_default_vision_provider == True  # noqa: E712
@@ -196,16 +206,18 @@ def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None
     )
     if not provider_model:
         return None
-    return FullLLMProvider.from_model(provider_model)
+    return LLMProviderView.from_model(provider_model)


-def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
+def fetch_llm_provider_view(
+    db_session: Session, provider_name: str
+) -> LLMProviderView | None:
     provider_model = db_session.scalar(
         select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
     )
     if not provider_model:
         return None
-    return FullLLMProvider.from_model(provider_model)
+    return LLMProviderView.from_model(provider_model)


 def remove_embedding_provider(

@@ -9,14 +9,14 @@ from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.llm import fetch_default_provider
 from onyx.db.llm import fetch_default_vision_provider
 from onyx.db.llm import fetch_existing_llm_providers
-from onyx.db.llm import fetch_provider
+from onyx.db.llm import fetch_llm_provider_view
 from onyx.db.models import Persona
 from onyx.llm.chat_llm import DefaultMultiLLM
 from onyx.llm.exceptions import GenAIDisabledException
 from onyx.llm.interfaces import LLM
 from onyx.llm.override_models import LLMOverride
 from onyx.llm.utils import model_supports_image_input
-from onyx.server.manage.llm.models import FullLLMProvider
+from onyx.server.manage.llm.models import LLMProviderView
 from onyx.utils.headers import build_llm_extra_headers
 from onyx.utils.logger import setup_logger
 from onyx.utils.long_term_log import LongTermLogger
@@ -62,7 +62,7 @@ def get_llms_for_persona(
         )

     with get_session_context_manager() as db_session:
-        llm_provider = fetch_provider(db_session, provider_name)
+        llm_provider = fetch_llm_provider_view(db_session, provider_name)

     if not llm_provider:
         raise ValueError("No LLM provider found")
@@ -106,7 +106,7 @@ def get_default_llm_with_vision(
     if DISABLE_GENERATIVE_AI:
         raise GenAIDisabledException()

-    def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM:
+    def create_vision_llm(provider: LLMProviderView, model: str) -> LLM:
         """Helper to create an LLM if the provider supports image input."""
         return get_llm(
             provider=provider.provider,
@@ -148,7 +148,7 @@ def get_default_llm_with_vision(
             provider.default_vision_model, provider.provider
         ):
             return create_vision_llm(
-                FullLLMProvider.from_model(provider), provider.default_vision_model
+                LLMProviderView.from_model(provider), provider.default_vision_model
             )

     return None

@@ -9,9 +9,9 @@ from sqlalchemy.orm import Session
 from onyx.auth.users import current_admin_user
 from onyx.auth.users import current_chat_accessible_user
 from onyx.db.engine import get_session
+from onyx.db.llm import fetch_existing_llm_provider
 from onyx.db.llm import fetch_existing_llm_providers
 from onyx.db.llm import fetch_existing_llm_providers_for_user
-from onyx.db.llm import fetch_provider
 from onyx.db.llm import remove_llm_provider
 from onyx.db.llm import update_default_provider
 from onyx.db.llm import update_default_vision_provider
@@ -24,9 +24,9 @@ from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
 from onyx.llm.utils import litellm_exception_to_error_msg
 from onyx.llm.utils import model_supports_image_input
 from onyx.llm.utils import test_llm
-from onyx.server.manage.llm.models import FullLLMProvider
 from onyx.server.manage.llm.models import LLMProviderDescriptor
 from onyx.server.manage.llm.models import LLMProviderUpsertRequest
+from onyx.server.manage.llm.models import LLMProviderView
 from onyx.server.manage.llm.models import TestLLMRequest
 from onyx.server.manage.llm.models import VisionProviderResponse
 from onyx.utils.logger import setup_logger
@@ -49,11 +49,27 @@ def fetch_llm_options(
 def test_llm_configuration(
     test_llm_request: TestLLMRequest,
     _: User | None = Depends(current_admin_user),
+    db_session: Session = Depends(get_session),
 ) -> None:
     """Test regular llm and fast llm settings"""
+
+    # the api key is sanitized if we are testing a provider already in the system
+
+    test_api_key = test_llm_request.api_key
+    if test_llm_request.name:
+        # NOTE: we are querying by name. we probably should be querying by an invariant id, but
+        # as it turns out the name is not editable in the UI and other code also keys off name,
+        # so we won't rock the boat just yet.
+        existing_provider = fetch_existing_llm_provider(
+            test_llm_request.name, db_session
+        )
+        if existing_provider:
+            test_api_key = existing_provider.api_key
+
     llm = get_llm(
         provider=test_llm_request.provider,
         model=test_llm_request.default_model_name,
-        api_key=test_llm_request.api_key,
+        api_key=test_api_key,
         api_base=test_llm_request.api_base,
         api_version=test_llm_request.api_version,
         custom_config=test_llm_request.custom_config,
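The hunk above covers the testing path: an admin testing an already-configured provider only holds the sanitized key, so the endpoint looks the real key up by provider name before building the LLM. A small sketch of that resolution, with a plain dict standing in for the fetch_existing_llm_provider lookup:

from dataclasses import dataclass


@dataclass
class StoredProvider:
    name: str
    api_key: str


def resolve_test_api_key(
    request_name: str | None,
    request_api_key: str | None,
    stored_providers: dict[str, StoredProvider],
) -> str | None:
    """Prefer the stored key for a named provider, since clients only see a masked copy."""
    if request_name and request_name in stored_providers:
        return stored_providers[request_name].api_key
    return request_api_key


stored = {"OpenAI": StoredProvider(name="OpenAI", api_key="sk-real-key")}
# testing an existing provider uses the stored key, not the masked one the UI holds
assert resolve_test_api_key("OpenAI", "sk-r****-key", stored) == "sk-real-key"
# testing a brand-new provider still uses whatever key was submitted
assert resolve_test_api_key(None, "sk-brand-new-key", stored) == "sk-brand-new-key"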
@@ -69,7 +85,7 @@ def test_llm_configuration(
     fast_llm = get_llm(
         provider=test_llm_request.provider,
         model=test_llm_request.fast_default_model_name,
-        api_key=test_llm_request.api_key,
+        api_key=test_api_key,
         api_base=test_llm_request.api_base,
         api_version=test_llm_request.api_version,
         custom_config=test_llm_request.custom_config,
@@ -119,11 +135,17 @@ def test_default_provider(
 def list_llm_providers(
     _: User | None = Depends(current_admin_user),
     db_session: Session = Depends(get_session),
-) -> list[FullLLMProvider]:
-    return [
-        FullLLMProvider.from_model(llm_provider_model)
-        for llm_provider_model in fetch_existing_llm_providers(db_session)
-    ]
+) -> list[LLMProviderView]:
+    llm_provider_list: list[LLMProviderView] = []
+    for llm_provider_model in fetch_existing_llm_providers(db_session):
+        full_llm_provider = LLMProviderView.from_model(llm_provider_model)
+        if full_llm_provider.api_key:
+            full_llm_provider.api_key = (
+                full_llm_provider.api_key[:4] + "****" + full_llm_provider.api_key[-4:]
+            )
+        llm_provider_list.append(full_llm_provider)
+
+    return llm_provider_list


 @admin_router.put("/provider")
@@ -135,11 +157,11 @@ def put_llm_provider(
     ),
     _: User | None = Depends(current_admin_user),
     db_session: Session = Depends(get_session),
-) -> FullLLMProvider:
+) -> LLMProviderView:
     # validate request (e.g. if we're intending to create but the name already exists we should throw an error)
     # NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
     # the result
-    existing_provider = fetch_provider(db_session, llm_provider.name)
+    existing_provider = fetch_existing_llm_provider(llm_provider.name, db_session)
     if existing_provider and is_creation:
         raise HTTPException(
             status_code=400,
@@ -161,6 +183,11 @@ def put_llm_provider(
             llm_provider.fast_default_model_name
         )

+    # the llm api key is sanitized when returned to clients, so the only time we
+    # should get a real key is when it is explicitly changed
+    if existing_provider and not llm_provider.api_key_changed:
+        llm_provider.api_key = existing_provider.api_key
+
     try:
         return upsert_llm_provider(
             llm_provider=llm_provider,
@@ -234,7 +261,7 @@ def get_vision_capable_providers(

         # Only include providers with at least one vision-capable model
         if vision_models:
-            provider_dict = FullLLMProvider.from_model(provider).model_dump()
+            provider_dict = LLMProviderView.from_model(provider).model_dump()
             provider_dict["vision_models"] = vision_models
             logger.info(
                 f"Vision provider: {provider.provider} with models: {vision_models}"

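The two comments added above describe the write path: GET responses only ever carry masked keys, so an update that round-trips the masked value must not clobber the stored secret. A hedged sketch of that decision, where stored_key stands in for the existing database row and None means the provider is being created:

def key_to_persist(
    submitted_key: str | None,
    api_key_changed: bool,
    stored_key: str | None,
) -> str | None:
    """Trust the submitted key only when the client marks it as changed (or on creation)."""
    if api_key_changed or stored_key is None:
        return submitted_key
    return stored_key


# editing other fields resubmits the mask but keeps the real key
assert key_to_persist("sk-0****0000", api_key_changed=False, stored_key="sk-real") == "sk-real"
# an explicit rotation is persisted as submitted
assert key_to_persist("sk-new-key", api_key_changed=True, stored_key="sk-real") == "sk-new-key"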
@@ -12,6 +12,7 @@ if TYPE_CHECKING:

 class TestLLMRequest(BaseModel):
     # provider level
+    name: str | None = None
     provider: str
     api_key: str | None = None
     api_base: str | None = None
@@ -76,16 +77,19 @@ class LLMProviderUpsertRequest(LLMProvider):
     # should only be used for a "custom" provider
     # for default providers, the built-in model names are used
     model_names: list[str] | None = None
+    api_key_changed: bool = False


-class FullLLMProvider(LLMProvider):
+class LLMProviderView(LLMProvider):
+    """Stripped down representation of LLMProvider for display / limited access info only"""
+
     id: int
     is_default_provider: bool | None = None
     is_default_vision_provider: bool | None = None
     model_names: list[str]

     @classmethod
-    def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider":
+    def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
         return cls(
             id=llm_provider_model.id,
             name=llm_provider_model.name,
@@ -111,7 +115,7 @@ class FullLLMProvider(LLMProvider):
         )


-class VisionProviderResponse(FullLLMProvider):
+class VisionProviderResponse(LLMProviderView):
     """Response model for vision providers endpoint, including vision-specific fields."""

     vision_models: list[str]

@@ -307,6 +307,7 @@ def setup_postgres(db_session: Session) -> None:
         groups=[],
         display_model_names=OPEN_AI_MODEL_NAMES,
         model_names=OPEN_AI_MODEL_NAMES,
+        api_key_changed=True,
     )
     new_llm_provider = upsert_llm_provider(
         llm_provider=model_req, db_session=db_session

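For orientation, a reduced sketch of the two Pydantic models after this change, limited to the key-handling fields (the real classes carry the full provider configuration): the upsert request gains the api_key_changed flag, and FullLLMProvider is renamed to LLMProviderView, the representation returned to clients with a masked api_key.

from pydantic import BaseModel


class LLMProviderUpsertRequest(BaseModel):
    # reduced sketch: only the fields relevant to key handling are shown
    name: str
    api_key: str | None = None
    api_key_changed: bool = False  # new in this commit


class LLMProviderView(BaseModel):
    # formerly FullLLMProvider; what admin clients receive back
    id: int
    name: str
    api_key: str | None = None  # masked before it leaves the server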
@@ -3,8 +3,8 @@ from uuid import uuid4

 import requests

-from onyx.server.manage.llm.models import FullLLMProvider
 from onyx.server.manage.llm.models import LLMProviderUpsertRequest
+from onyx.server.manage.llm.models import LLMProviderView
 from tests.integration.common_utils.constants import API_SERVER_URL
 from tests.integration.common_utils.constants import GENERAL_HEADERS
 from tests.integration.common_utils.test_models import DATestLLMProvider
@@ -39,6 +39,7 @@ class LLMProviderManager:
             groups=groups or [],
             display_model_names=None,
             model_names=None,
+            api_key_changed=True,
         )

         llm_response = requests.put(
@@ -90,7 +91,7 @@ class LLMProviderManager:
     @staticmethod
     def get_all(
         user_performing_action: DATestUser | None = None,
-    ) -> list[FullLLMProvider]:
+    ) -> list[LLMProviderView]:
         response = requests.get(
             f"{API_SERVER_URL}/admin/llm/provider",
             headers=user_performing_action.headers
@@ -98,7 +99,7 @@ class LLMProviderManager:
             else GENERAL_HEADERS,
         )
         response.raise_for_status()
-        return [FullLLMProvider(**ug) for ug in response.json()]
+        return [LLMProviderView(**ug) for ug in response.json()]

     @staticmethod
     def verify(
@@ -111,18 +112,19 @@ class LLMProviderManager:
             if llm_provider.id == fetched_llm_provider.id:
                 if verify_deleted:
                     raise ValueError(
-                        f"User group {llm_provider.id} found but should be deleted"
+                        f"LLM Provider {llm_provider.id} found but should be deleted"
                     )
                 fetched_llm_groups = set(fetched_llm_provider.groups)
                 llm_provider_groups = set(llm_provider.groups)
+
+                # NOTE: returned api keys are sanitized and should not match
                 if (
                     fetched_llm_groups == llm_provider_groups
                     and llm_provider.provider == fetched_llm_provider.provider
-                    and llm_provider.api_key == fetched_llm_provider.api_key
                     and llm_provider.default_model_name
                     == fetched_llm_provider.default_model_name
                     and llm_provider.is_public == fetched_llm_provider.is_public
                 ):
                     return
         if not verify_deleted:
-            raise ValueError(f"User group {llm_provider.id} not found")
+            raise ValueError(f"LLM Provider {llm_provider.id} not found")

@@ -34,6 +34,7 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
         json={
             "name": str(uuid.uuid4()),
             "provider": "openai",
+            "api_key": "sk-000000000000000000000000000000000000000000000000",
             "default_model_name": _DEFAULT_MODELS[0],
             "model_names": _DEFAULT_MODELS,
             "is_public": True,
@@ -49,6 +50,9 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
     assert provider_data["model_names"] == _DEFAULT_MODELS
     assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
     assert provider_data["display_model_names"] is None
+    assert (
+        provider_data["api_key"] == "sk-0****0000"
+    )  # test that returned key is sanitized


 def test_update_llm_provider_model_names(reset: None) -> None:
@@ -64,10 +68,12 @@ def test_update_llm_provider_model_names(reset: None) -> None:
         json={
             "name": name,
             "provider": "openai",
+            "api_key": "sk-000000000000000000000000000000000000000000000000",
             "default_model_name": _DEFAULT_MODELS[0],
             "model_names": [_DEFAULT_MODELS[0]],
             "is_public": True,
             "groups": [],
+            "api_key_changed": True,
         },
     )
     assert response.status_code == 200
@@ -81,6 +87,7 @@ def test_update_llm_provider_model_names(reset: None) -> None:
             "id": created_provider["id"],
             "name": name,
             "provider": created_provider["provider"],
+            "api_key": "sk-000000000000000000000000000000000000000000000001",
             "default_model_name": _DEFAULT_MODELS[0],
             "model_names": _DEFAULT_MODELS,
             "is_public": True,
@@ -93,6 +100,30 @@ def test_update_llm_provider_model_names(reset: None) -> None:
     provider_data = _get_provider_by_id(admin_user, created_provider["id"])
     assert provider_data is not None
     assert provider_data["model_names"] == _DEFAULT_MODELS
+    assert (
+        provider_data["api_key"] == "sk-0****0000"
+    )  # test that key was NOT updated due to api_key_changed not being set
+
+    # Update with api_key_changed properly set
+    response = requests.put(
+        f"{API_SERVER_URL}/admin/llm/provider",
+        headers=admin_user.headers,
+        json={
+            "id": created_provider["id"],
+            "name": name,
+            "provider": created_provider["provider"],
+            "api_key": "sk-000000000000000000000000000000000000000000000001",
+            "default_model_name": _DEFAULT_MODELS[0],
+            "model_names": _DEFAULT_MODELS,
+            "is_public": True,
+            "groups": [],
+            "api_key_changed": True,
+        },
+    )
+    assert response.status_code == 200
+    provider_data = _get_provider_by_id(admin_user, created_provider["id"])
+    assert provider_data is not None
+    assert provider_data["api_key"] == "sk-0****0001"  # test that key was updated


 def test_delete_llm_provider(reset: None) -> None:
@@ -107,6 +138,7 @@ def test_delete_llm_provider(reset: None) -> None:
         json={
             "name": "test-provider-delete",
             "provider": "openai",
+            "api_key": "sk-000000000000000000000000000000000000000000000000",
             "default_model_name": _DEFAULT_MODELS[0],
             "model_names": _DEFAULT_MODELS,
             "is_public": True,

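The integration-test assertions above pin down the exact masked literals. A quick check that those values follow from the first-four/last-four masking scheme (hypothetical helper, not part of the test suite):

def mask(key: str) -> str:
    # first four characters + "****" + last four characters
    return key[:4] + "****" + key[-4:]


assert mask("sk-000000000000000000000000000000000000000000000000") == "sk-0****0000"
assert mask("sk-000000000000000000000000000000000000000000000001") == "sk-0****0001"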
@@ -61,7 +61,7 @@ import {
 import { buildImgUrl } from "@/app/chat/files/images/utils";
 import { useAssistants } from "@/components/context/AssistantsContext";
 import { debounce } from "lodash";
-import { FullLLMProvider } from "../configuration/llm/interfaces";
+import { LLMProviderView } from "../configuration/llm/interfaces";
 import StarterMessagesList from "./StarterMessageList";

 import { Switch, SwitchField } from "@/components/ui/switch";
@@ -123,7 +123,7 @@ export function AssistantEditor({
   documentSets: DocumentSet[];
   user: User | null;
   defaultPublic: boolean;
-  llmProviders: FullLLMProvider[];
+  llmProviders: LLMProviderView[];
   tools: ToolSnapshot[];
   shouldAddAssistantToUserPreferences?: boolean;
   admin?: boolean;

@@ -1,4 +1,4 @@
-import { FullLLMProvider } from "../configuration/llm/interfaces";
+import { LLMProviderView } from "../configuration/llm/interfaces";
 import { Persona, StarterMessage } from "./interfaces";

 interface PersonaUpsertRequest {
@@ -319,7 +319,7 @@ export function checkPersonaRequiresImageGeneration(persona: Persona) {
 }

 export function providersContainImageGeneratingSupport(
-  providers: FullLLMProvider[]
+  providers: LLMProviderView[]
 ) {
   return providers.some((provider) => provider.provider === "openai");
 }

@@ -1,5 +1,5 @@
 import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
-import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
+import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
 import { Modal } from "@/components/Modal";
 import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
 import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm";
@@ -19,7 +19,7 @@ function LLMProviderUpdateModal({
 }: {
   llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined;
   onClose: () => void;
-  existingLlmProvider?: FullLLMProvider;
+  existingLlmProvider?: LLMProviderView;
   shouldMarkAsDefault?: boolean;
   setPopup?: (popup: PopupSpec) => void;
 }) {
@@ -61,7 +61,7 @@ function LLMProviderDisplay({
   shouldMarkAsDefault,
 }: {
   llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined;
-  existingLlmProvider: FullLLMProvider;
+  existingLlmProvider: LLMProviderView;
   shouldMarkAsDefault?: boolean;
 }) {
   const [formIsVisible, setFormIsVisible] = useState(false);
@@ -146,7 +146,7 @@ export function ConfiguredLLMProviderDisplay({
   existingLlmProviders,
   llmProviderDescriptors,
 }: {
-  existingLlmProviders: FullLLMProvider[];
+  existingLlmProviders: LLMProviderView[];
   llmProviderDescriptors: WellKnownLLMProviderDescriptor[];
 }) {
   existingLlmProviders = existingLlmProviders.sort((a, b) => {

@@ -21,7 +21,7 @@ import {
 } from "@/components/admin/connectors/Field";
 import { useState } from "react";
 import { useSWRConfig } from "swr";
-import { FullLLMProvider } from "./interfaces";
+import { LLMProviderView } from "./interfaces";
 import { PopupSpec } from "@/components/admin/connectors/Popup";
 import * as Yup from "yup";
 import isEqual from "lodash/isEqual";
@@ -43,7 +43,7 @@ export function CustomLLMProviderUpdateForm({
   hideSuccess,
 }: {
   onClose: () => void;
-  existingLlmProvider?: FullLLMProvider;
+  existingLlmProvider?: LLMProviderView;
   shouldMarkAsDefault?: boolean;
   setPopup?: (popup: PopupSpec) => void;
   hideSuccess?: boolean;
@@ -165,7 +165,7 @@ export function CustomLLMProviderUpdateForm({
       }

       if (shouldMarkAsDefault) {
-        const newLlmProvider = (await response.json()) as FullLLMProvider;
+        const newLlmProvider = (await response.json()) as LLMProviderView;
         const setDefaultResponse = await fetch(
           `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
           {

@@ -9,7 +9,7 @@ import Text from "@/components/ui/text";
 import Title from "@/components/ui/title";
 import { Button } from "@/components/ui/button";
 import { ThreeDotsLoader } from "@/components/Loading";
-import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
+import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
 import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
 import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
 import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
@@ -25,7 +25,7 @@ function LLMProviderUpdateModal({
 }: {
   llmProviderDescriptor: WellKnownLLMProviderDescriptor | null;
   onClose: () => void;
-  existingLlmProvider?: FullLLMProvider;
+  existingLlmProvider?: LLMProviderView;
   shouldMarkAsDefault?: boolean;
   setPopup?: (popup: PopupSpec) => void;
 }) {
@@ -99,7 +99,7 @@ function DefaultLLMProviderDisplay({
 function AddCustomLLMProvider({
   existingLlmProviders,
 }: {
-  existingLlmProviders: FullLLMProvider[];
+  existingLlmProviders: LLMProviderView[];
 }) {
   const [formIsVisible, setFormIsVisible] = useState(false);

@@ -130,7 +130,7 @@ export function LLMConfiguration() {
   const { data: llmProviderDescriptors } = useSWR<
     WellKnownLLMProviderDescriptor[]
   >("/api/admin/llm/built-in/options", errorHandlingFetcher);
-  const { data: existingLlmProviders } = useSWR<FullLLMProvider[]>(
+  const { data: existingLlmProviders } = useSWR<LLMProviderView[]>(
     LLM_PROVIDERS_ADMIN_URL,
     errorHandlingFetcher
   );

@@ -14,7 +14,7 @@ import {
 import { useState } from "react";
 import { useSWRConfig } from "swr";
 import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks";
-import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
+import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
 import { PopupSpec } from "@/components/admin/connectors/Popup";
 import * as Yup from "yup";
 import isEqual from "lodash/isEqual";
@@ -31,7 +31,7 @@ export function LLMProviderUpdateForm({
 }: {
   llmProviderDescriptor: WellKnownLLMProviderDescriptor;
   onClose: () => void;
-  existingLlmProvider?: FullLLMProvider;
+  existingLlmProvider?: LLMProviderView;
   shouldMarkAsDefault?: boolean;
   hideAdvanced?: boolean;
   setPopup?: (popup: PopupSpec) => void;
@@ -73,6 +73,7 @@ export function LLMProviderUpdateForm({
       defaultModelsByProvider[llmProviderDescriptor.name] ||
       [],
     deployment_name: existingLlmProvider?.deployment_name,
+    api_key_changed: false,
   };

   // Setup validation schema if required
@@ -113,6 +114,7 @@ export function LLMProviderUpdateForm({
     is_public: Yup.boolean().required(),
     groups: Yup.array().of(Yup.number()),
     display_model_names: Yup.array().of(Yup.string()),
+    api_key_changed: Yup.boolean(),
   });

   return (
@@ -122,6 +124,8 @@ export function LLMProviderUpdateForm({
       onSubmit={async (values, { setSubmitting }) => {
         setSubmitting(true);

+        values.api_key_changed = values.api_key !== initialValues.api_key;
+
         // test the configuration
         if (!isEqual(values, initialValues)) {
           setIsTesting(true);
@@ -180,7 +184,7 @@ export function LLMProviderUpdateForm({
       }

       if (shouldMarkAsDefault) {
-        const newLlmProvider = (await response.json()) as FullLLMProvider;
+        const newLlmProvider = (await response.json()) as LLMProviderView;
         const setDefaultResponse = await fetch(
           `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
           {

@@ -53,14 +53,14 @@ export interface LLMProvider {
   is_default_vision_provider: boolean | null;
 }

-export interface FullLLMProvider extends LLMProvider {
+export interface LLMProviderView extends LLMProvider {
   id: number;
   is_default_provider: boolean | null;
   model_names: string[];
   icon?: React.FC<{ size?: number; className?: string }>;
 }

-export interface VisionProvider extends FullLLMProvider {
+export interface VisionProvider extends LLMProviderView {
   vision_models: string[];
 }

@@ -1,5 +1,5 @@
 import {
-  FullLLMProvider,
+  LLMProviderView,
   WellKnownLLMProviderDescriptor,
 } from "@/app/admin/configuration/llm/interfaces";
 import { User } from "@/lib/types";
@@ -36,7 +36,7 @@ export async function checkLlmProvider(user: User | null) {
   const [providerResponse, optionsResponse, defaultCheckResponse] =
     await Promise.all(tasks);

-  let providers: FullLLMProvider[] = [];
+  let providers: LLMProviderView[] = [];
   if (providerResponse?.ok) {
     providers = await providerResponse.json();
   }

@@ -3,7 +3,7 @@ import { CCPairBasicInfo, DocumentSet, User } from "../types";
 import { getCurrentUserSS } from "../userSS";
 import { fetchSS } from "../utilsSS";
 import {
-  FullLLMProvider,
+  LLMProviderView,
   getProviderIcon,
 } from "@/app/admin/configuration/llm/interfaces";
 import { ToolSnapshot } from "../tools/interfaces";
@@ -16,7 +16,7 @@ export async function fetchAssistantEditorInfoSS(
   {
     ccPairs: CCPairBasicInfo[];
     documentSets: DocumentSet[];
-    llmProviders: FullLLMProvider[];
+    llmProviders: LLMProviderView[];
     user: User | null;
     existingPersona: Persona | null;
     tools: ToolSnapshot[];
@@ -83,7 +83,7 @@ export async function fetchAssistantEditorInfoSS(
     ];
   }

-  const llmProviders = (await llmProvidersResponse.json()) as FullLLMProvider[];
+  const llmProviders = (await llmProvidersResponse.json()) as LLMProviderView[];

   if (personaId && personaResponse && !personaResponse.ok) {
     return [null, `Failed to fetch Persona - ${await personaResponse.text()}`];