mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-23 10:11:00 +02:00
Add tests for some LLM provider endpoints + small logic change to ensure that display_model_names is not empty
This commit is contained in:
parent
bf78fb79f8
commit
1470b7e038
@ -142,19 +142,20 @@ def put_llm_provider(
|
||||
detail=f"LLM Provider with name {llm_provider.name} already exists",
|
||||
)
|
||||
|
||||
if llm_provider.display_model_names is not None:
|
||||
# Ensure default_model_name and fast_default_model_name are in display_model_names
|
||||
# This is necessary for custom models and Bedrock/Azure models
|
||||
if llm_provider.display_model_names is None:
|
||||
llm_provider.display_model_names = []
|
||||
|
||||
if llm_provider.default_model_name not in llm_provider.display_model_names:
|
||||
llm_provider.display_model_names.append(llm_provider.default_model_name)
|
||||
|
||||
if (
|
||||
llm_provider.fast_default_model_name
|
||||
and llm_provider.fast_default_model_name not in llm_provider.display_model_names
|
||||
and llm_provider.fast_default_model_name
|
||||
not in llm_provider.display_model_names
|
||||
):
|
||||
llm_provider.display_model_names.append(llm_provider.fast_default_model_name)
|
||||
llm_provider.display_model_names.append(
|
||||
llm_provider.fast_default_model_name
|
||||
)
|
||||
|
||||
try:
|
||||
return upsert_llm_provider(
|
||||
|
@ -4,8 +4,12 @@ from collections.abc import Generator
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.reset import reset_all_multitenant
|
||||
@ -57,6 +61,30 @@ def new_admin_user(reset: None) -> DATestUser | None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user() -> DATestUser | None:
|
||||
try:
|
||||
return UserManager.create(name="admin_user")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("admin_user"),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_multitenant() -> None:
|
||||
reset_all_multitenant()
|
||||
|
@ -7,40 +7,12 @@ import requests
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user import UserRole
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user() -> DATestUser | None:
|
||||
try:
|
||||
return UserManager.create("admin_user")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("admin_user"),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
|
||||
return LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
@ -0,0 +1,120 @@
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
_DEFAULT_MODELS = ["gpt-4", "gpt-4o"]
|
||||
|
||||
|
||||
def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None:
|
||||
"""Utility function to fetch an LLM provider by ID"""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
return next((p for p in providers if p["id"] == provider_id), None)
|
||||
|
||||
|
||||
def test_create_llm_provider_without_display_model_names(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test creating an LLM provider without specifying
|
||||
display_model_names and verify it's null in response"""
|
||||
# Create LLM provider without model_names
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": str(uuid.uuid4()),
|
||||
"provider": "openai",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
created_provider = response.json()
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
|
||||
# Verify model_names is None/null
|
||||
assert provider_data is not None
|
||||
assert provider_data["model_names"] == _DEFAULT_MODELS
|
||||
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
|
||||
assert provider_data["display_model_names"] is None
|
||||
|
||||
|
||||
def test_update_llm_provider_model_names(admin_user: DATestUser) -> None:
|
||||
"""Test updating an LLM provider's model_names"""
|
||||
# First create provider without model_names
|
||||
name = str(uuid.uuid4())
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": name,
|
||||
"provider": "openai",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": [_DEFAULT_MODELS[0]],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
created_provider = response.json()
|
||||
|
||||
# Update with model_names
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"id": created_provider["id"],
|
||||
"name": name,
|
||||
"provider": created_provider["provider"],
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify update
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
assert provider_data is not None
|
||||
assert provider_data["model_names"] == _DEFAULT_MODELS
|
||||
|
||||
|
||||
def test_delete_llm_provider(admin_user: DATestUser) -> None:
|
||||
"""Test deleting an LLM provider"""
|
||||
# Create a provider
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "test-provider-delete",
|
||||
"provider": "openai",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
created_provider = response.json()
|
||||
|
||||
# Delete the provider
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify provider is deleted by checking it's not in the list
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
assert provider_data is None
|
Loading…
x
Reference in New Issue
Block a user