mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-24 18:50:06 +02:00
Add tests for some LLM provider endpoints + small logic change to ensure that display_model_names is not empty
This commit is contained in:
parent
bf78fb79f8
commit
1470b7e038
@ -142,19 +142,20 @@ def put_llm_provider(
|
|||||||
detail=f"LLM Provider with name {llm_provider.name} already exists",
|
detail=f"LLM Provider with name {llm_provider.name} already exists",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure default_model_name and fast_default_model_name are in display_model_names
|
if llm_provider.display_model_names is not None:
|
||||||
# This is necessary for custom models and Bedrock/Azure models
|
# Ensure default_model_name and fast_default_model_name are in display_model_names
|
||||||
if llm_provider.display_model_names is None:
|
# This is necessary for custom models and Bedrock/Azure models
|
||||||
llm_provider.display_model_names = []
|
if llm_provider.default_model_name not in llm_provider.display_model_names:
|
||||||
|
llm_provider.display_model_names.append(llm_provider.default_model_name)
|
||||||
|
|
||||||
if llm_provider.default_model_name not in llm_provider.display_model_names:
|
if (
|
||||||
llm_provider.display_model_names.append(llm_provider.default_model_name)
|
llm_provider.fast_default_model_name
|
||||||
|
and llm_provider.fast_default_model_name
|
||||||
if (
|
not in llm_provider.display_model_names
|
||||||
llm_provider.fast_default_model_name
|
):
|
||||||
and llm_provider.fast_default_model_name not in llm_provider.display_model_names
|
llm_provider.display_model_names.append(
|
||||||
):
|
llm_provider.fast_default_model_name
|
||||||
llm_provider.display_model_names.append(llm_provider.fast_default_model_name)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return upsert_llm_provider(
|
return upsert_llm_provider(
|
||||||
|
@ -4,8 +4,12 @@ from collections.abc import Generator
|
|||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from onyx.auth.schemas import UserRole
|
||||||
from onyx.db.engine import get_session_context_manager
|
from onyx.db.engine import get_session_context_manager
|
||||||
from onyx.db.search_settings import get_current_search_settings
|
from onyx.db.search_settings import get_current_search_settings
|
||||||
|
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||||
|
from tests.integration.common_utils.managers.user import build_email
|
||||||
|
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||||
from tests.integration.common_utils.managers.user import UserManager
|
from tests.integration.common_utils.managers.user import UserManager
|
||||||
from tests.integration.common_utils.reset import reset_all
|
from tests.integration.common_utils.reset import reset_all
|
||||||
from tests.integration.common_utils.reset import reset_all_multitenant
|
from tests.integration.common_utils.reset import reset_all_multitenant
|
||||||
@ -57,6 +61,30 @@ def new_admin_user(reset: None) -> DATestUser | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def admin_user() -> DATestUser | None:
|
||||||
|
try:
|
||||||
|
return UserManager.create(name="admin_user")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
return UserManager.login_as_user(
|
||||||
|
DATestUser(
|
||||||
|
id="",
|
||||||
|
email=build_email("admin_user"),
|
||||||
|
password=DEFAULT_PASSWORD,
|
||||||
|
headers=GENERAL_HEADERS,
|
||||||
|
role=UserRole.ADMIN,
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def reset_multitenant() -> None:
|
def reset_multitenant() -> None:
|
||||||
reset_all_multitenant()
|
reset_all_multitenant()
|
||||||
|
@ -7,40 +7,12 @@ import requests
|
|||||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||||
from tests.integration.common_utils.managers.user import build_email
|
|
||||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
|
||||||
from tests.integration.common_utils.managers.user import UserManager
|
|
||||||
from tests.integration.common_utils.managers.user import UserRole
|
|
||||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||||
from tests.integration.common_utils.test_models import DATestUser
|
from tests.integration.common_utils.test_models import DATestUser
|
||||||
|
|
||||||
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
|
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def admin_user() -> DATestUser | None:
|
|
||||||
try:
|
|
||||||
return UserManager.create("admin_user")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
return UserManager.login_as_user(
|
|
||||||
DATestUser(
|
|
||||||
id="",
|
|
||||||
email=build_email("admin_user"),
|
|
||||||
password=DEFAULT_PASSWORD,
|
|
||||||
headers=GENERAL_HEADERS,
|
|
||||||
role=UserRole.ADMIN,
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
|
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
|
||||||
return LLMProviderManager.create(user_performing_action=admin_user)
|
return LLMProviderManager.create(user_performing_action=admin_user)
|
||||||
|
@ -0,0 +1,120 @@
|
|||||||
|
import uuid
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||||
|
from tests.integration.common_utils.test_models import DATestUser
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_MODELS = ["gpt-4", "gpt-4o"]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None:
|
||||||
|
"""Utility function to fetch an LLM provider by ID"""
|
||||||
|
response = requests.get(
|
||||||
|
f"{API_SERVER_URL}/admin/llm/provider",
|
||||||
|
headers=admin_user.headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
providers = response.json()
|
||||||
|
return next((p for p in providers if p["id"] == provider_id), None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_provider_without_display_model_names(
|
||||||
|
admin_user: DATestUser,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating an LLM provider without specifying
|
||||||
|
display_model_names and verify it's null in response"""
|
||||||
|
# Create LLM provider without model_names
|
||||||
|
response = requests.put(
|
||||||
|
f"{API_SERVER_URL}/admin/llm/provider",
|
||||||
|
headers=admin_user.headers,
|
||||||
|
json={
|
||||||
|
"name": str(uuid.uuid4()),
|
||||||
|
"provider": "openai",
|
||||||
|
"default_model_name": _DEFAULT_MODELS[0],
|
||||||
|
"model_names": _DEFAULT_MODELS,
|
||||||
|
"is_public": True,
|
||||||
|
"groups": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
created_provider = response.json()
|
||||||
|
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||||
|
|
||||||
|
# Verify model_names is None/null
|
||||||
|
assert provider_data is not None
|
||||||
|
assert provider_data["model_names"] == _DEFAULT_MODELS
|
||||||
|
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
|
||||||
|
assert provider_data["display_model_names"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_llm_provider_model_names(admin_user: DATestUser) -> None:
|
||||||
|
"""Test updating an LLM provider's model_names"""
|
||||||
|
# First create provider without model_names
|
||||||
|
name = str(uuid.uuid4())
|
||||||
|
response = requests.put(
|
||||||
|
f"{API_SERVER_URL}/admin/llm/provider",
|
||||||
|
headers=admin_user.headers,
|
||||||
|
json={
|
||||||
|
"name": name,
|
||||||
|
"provider": "openai",
|
||||||
|
"default_model_name": _DEFAULT_MODELS[0],
|
||||||
|
"model_names": [_DEFAULT_MODELS[0]],
|
||||||
|
"is_public": True,
|
||||||
|
"groups": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
created_provider = response.json()
|
||||||
|
|
||||||
|
# Update with model_names
|
||||||
|
response = requests.put(
|
||||||
|
f"{API_SERVER_URL}/admin/llm/provider",
|
||||||
|
headers=admin_user.headers,
|
||||||
|
json={
|
||||||
|
"id": created_provider["id"],
|
||||||
|
"name": name,
|
||||||
|
"provider": created_provider["provider"],
|
||||||
|
"default_model_name": _DEFAULT_MODELS[0],
|
||||||
|
"model_names": _DEFAULT_MODELS,
|
||||||
|
"is_public": True,
|
||||||
|
"groups": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Verify update
|
||||||
|
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||||
|
assert provider_data is not None
|
||||||
|
assert provider_data["model_names"] == _DEFAULT_MODELS
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_llm_provider(admin_user: DATestUser) -> None:
|
||||||
|
"""Test deleting an LLM provider"""
|
||||||
|
# Create a provider
|
||||||
|
response = requests.put(
|
||||||
|
f"{API_SERVER_URL}/admin/llm/provider",
|
||||||
|
headers=admin_user.headers,
|
||||||
|
json={
|
||||||
|
"name": "test-provider-delete",
|
||||||
|
"provider": "openai",
|
||||||
|
"default_model_name": _DEFAULT_MODELS[0],
|
||||||
|
"model_names": _DEFAULT_MODELS,
|
||||||
|
"is_public": True,
|
||||||
|
"groups": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
created_provider = response.json()
|
||||||
|
|
||||||
|
# Delete the provider
|
||||||
|
response = requests.delete(
|
||||||
|
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
|
||||||
|
headers=admin_user.headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Verify provider is deleted by checking it's not in the list
|
||||||
|
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||||
|
assert provider_data is None
|
Loading…
x
Reference in New Issue
Block a user