Add tests for some LLM provider endpoints + small logic change to ensure that display_model_names is not empty

This commit is contained in:
Weves 2025-01-09 20:03:02 -08:00 committed by Chris Weaver
parent bf78fb79f8
commit 1470b7e038
4 changed files with 161 additions and 40 deletions

View File

@ -142,19 +142,20 @@ def put_llm_provider(
detail=f"LLM Provider with name {llm_provider.name} already exists",
)
# Ensure default_model_name and fast_default_model_name are in display_model_names
# This is necessary for custom models and Bedrock/Azure models
if llm_provider.display_model_names is None:
llm_provider.display_model_names = []
if llm_provider.display_model_names is not None:
# Ensure default_model_name and fast_default_model_name are in display_model_names
# This is necessary for custom models and Bedrock/Azure models
if llm_provider.default_model_name not in llm_provider.display_model_names:
llm_provider.display_model_names.append(llm_provider.default_model_name)
if llm_provider.default_model_name not in llm_provider.display_model_names:
llm_provider.display_model_names.append(llm_provider.default_model_name)
if (
llm_provider.fast_default_model_name
and llm_provider.fast_default_model_name not in llm_provider.display_model_names
):
llm_provider.display_model_names.append(llm_provider.fast_default_model_name)
if (
llm_provider.fast_default_model_name
and llm_provider.fast_default_model_name
not in llm_provider.display_model_names
):
llm_provider.display_model_names.append(
llm_provider.fast_default_model_name
)
try:
return upsert_llm_provider(

View File

@ -4,8 +4,12 @@ from collections.abc import Generator
import pytest
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.reset import reset_all_multitenant
@ -57,6 +61,30 @@ def new_admin_user(reset: None) -> DATestUser | None:
return None
@pytest.fixture
def admin_user() -> DATestUser | None:
try:
return UserManager.create(name="admin_user")
except Exception:
pass
try:
return UserManager.login_as_user(
DATestUser(
id="",
email=build_email("admin_user"),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
role=UserRole.ADMIN,
is_active=True,
)
)
except Exception:
pass
return None
@pytest.fixture
def reset_multitenant() -> None:
reset_all_multitenant()

View File

@ -7,40 +7,12 @@ import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user import UserRole
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
@pytest.fixture
def admin_user() -> DATestUser | None:
try:
return UserManager.create("admin_user")
except Exception:
pass
try:
return UserManager.login_as_user(
DATestUser(
id="",
email=build_email("admin_user"),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
role=UserRole.ADMIN,
is_active=True,
)
)
except Exception:
pass
return None
@pytest.fixture
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
return LLMProviderManager.create(user_performing_action=admin_user)

View File

@ -0,0 +1,120 @@
import uuid
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.test_models import DATestUser
_DEFAULT_MODELS = ["gpt-4", "gpt-4o"]
def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None:
"""Utility function to fetch an LLM provider by ID"""
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
)
assert response.status_code == 200
providers = response.json()
return next((p for p in providers if p["id"] == provider_id), None)
def test_create_llm_provider_without_display_model_names(
admin_user: DATestUser,
) -> None:
"""Test creating an LLM provider without specifying
display_model_names and verify it's null in response"""
# Create LLM provider without model_names
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"name": str(uuid.uuid4()),
"provider": "openai",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
"groups": [],
},
)
assert response.status_code == 200
created_provider = response.json()
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
# Verify model_names is None/null
assert provider_data is not None
assert provider_data["model_names"] == _DEFAULT_MODELS
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
assert provider_data["display_model_names"] is None
def test_update_llm_provider_model_names(admin_user: DATestUser) -> None:
"""Test updating an LLM provider's model_names"""
# First create provider without model_names
name = str(uuid.uuid4())
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"name": name,
"provider": "openai",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": [_DEFAULT_MODELS[0]],
"is_public": True,
"groups": [],
},
)
assert response.status_code == 200
created_provider = response.json()
# Update with model_names
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"id": created_provider["id"],
"name": name,
"provider": created_provider["provider"],
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
"groups": [],
},
)
assert response.status_code == 200
# Verify update
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is not None
assert provider_data["model_names"] == _DEFAULT_MODELS
def test_delete_llm_provider(admin_user: DATestUser) -> None:
"""Test deleting an LLM provider"""
# Create a provider
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"name": "test-provider-delete",
"provider": "openai",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
"groups": [],
},
)
assert response.status_code == 200
created_provider = response.json()
# Delete the provider
response = requests.delete(
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
headers=admin_user.headers,
)
assert response.status_code == 200
# Verify provider is deleted by checking it's not in the list
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is None