mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-21 14:12:42 +02:00
Add support for more /models/list formats (#3739)
This commit is contained in:
@@ -14,8 +14,16 @@ from onyx.db.models import LLMProvider
|
||||
|
||||
def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
# Handle case where response is wrapped in a "data" field
|
||||
if isinstance(model_list_json, dict) and "data" in model_list_json:
|
||||
model_list_json = model_list_json["data"]
|
||||
if isinstance(model_list_json, dict):
|
||||
if "data" in model_list_json:
|
||||
model_list_json = model_list_json["data"]
|
||||
elif "models" in model_list_json:
|
||||
model_list_json = model_list_json["models"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid response from API - expected dict with 'data' or "
|
||||
f"'models' field, got {type(model_list_json)}"
|
||||
)
|
||||
|
||||
if not isinstance(model_list_json, list):
|
||||
raise ValueError(
|
||||
@@ -27,11 +35,18 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
for item in model_list_json:
|
||||
if isinstance(item, str):
|
||||
model_names.append(item)
|
||||
elif isinstance(item, dict) and "model_name" in item:
|
||||
model_names.append(item["model_name"])
|
||||
elif isinstance(item, dict):
|
||||
if "model_name" in item:
|
||||
model_names.append(item["model_name"])
|
||||
elif "id" in item:
|
||||
model_names.append(item["id"])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid item in model list - expected dict with model_name or id, got {type(item)}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
|
||||
f"Invalid item in model list - expected string or dict, got {type(item)}"
|
||||
)
|
||||
|
||||
return model_names
|
||||
|
@@ -0,0 +1,92 @@
|
||||
import pytest
|
||||
|
||||
from onyx.background.celery.tasks.llm_model_update.tasks import (
|
||||
_process_model_list_response,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_data,expected_result,expected_error,error_match",
|
||||
[
|
||||
# Success cases
|
||||
(
|
||||
["gpt-4", "gpt-3.5-turbo", "claude-2"],
|
||||
["gpt-4", "gpt-3.5-turbo", "claude-2"],
|
||||
None,
|
||||
None,
|
||||
),
|
||||
(
|
||||
[
|
||||
{"model_name": "gpt-4", "other_field": "value"},
|
||||
{"model_name": "gpt-3.5-turbo", "other_field": "value"},
|
||||
],
|
||||
["gpt-4", "gpt-3.5-turbo"],
|
||||
None,
|
||||
None,
|
||||
),
|
||||
(
|
||||
[
|
||||
{"id": "gpt-4", "other_field": "value"},
|
||||
{"id": "gpt-3.5-turbo", "other_field": "value"},
|
||||
],
|
||||
["gpt-4", "gpt-3.5-turbo"],
|
||||
None,
|
||||
None,
|
||||
),
|
||||
(
|
||||
{"data": ["gpt-4", "gpt-3.5-turbo"]},
|
||||
["gpt-4", "gpt-3.5-turbo"],
|
||||
None,
|
||||
None,
|
||||
),
|
||||
(
|
||||
{"models": ["gpt-4", "gpt-3.5-turbo"]},
|
||||
["gpt-4", "gpt-3.5-turbo"],
|
||||
None,
|
||||
None,
|
||||
),
|
||||
(
|
||||
{"models": [{"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}]},
|
||||
["gpt-4", "gpt-3.5-turbo"],
|
||||
None,
|
||||
None,
|
||||
),
|
||||
# Error cases
|
||||
(
|
||||
"not a list",
|
||||
None,
|
||||
ValueError,
|
||||
"Invalid response from API - expected list",
|
||||
),
|
||||
(
|
||||
{"wrong_field": []},
|
||||
None,
|
||||
ValueError,
|
||||
"Invalid response from API - expected dict with 'data' or 'models' field",
|
||||
),
|
||||
(
|
||||
[{"wrong_field": "value"}],
|
||||
None,
|
||||
ValueError,
|
||||
"Invalid item in model list - expected dict with model_name or id",
|
||||
),
|
||||
(
|
||||
[42],
|
||||
None,
|
||||
ValueError,
|
||||
"Invalid item in model list - expected string or dict",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_process_model_list_response(
|
||||
input_data: dict | list,
|
||||
expected_result: list[str] | None,
|
||||
expected_error: type[Exception] | None,
|
||||
error_match: str | None,
|
||||
) -> None:
|
||||
if expected_error:
|
||||
with pytest.raises(expected_error, match=error_match):
|
||||
_process_model_list_response(input_data)
|
||||
else:
|
||||
result = _process_model_list_response(input_data)
|
||||
assert result == expected_result
|
Reference in New Issue
Block a user