mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-25 11:16:43 +02:00
Add support for more /models/list formats (#3739)
This commit is contained in:
@@ -14,8 +14,16 @@ from onyx.db.models import LLMProvider
|
|||||||
|
|
||||||
def _process_model_list_response(model_list_json: Any) -> list[str]:
|
def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||||
# Handle case where response is wrapped in a "data" field
|
# Handle case where response is wrapped in a "data" field
|
||||||
if isinstance(model_list_json, dict) and "data" in model_list_json:
|
if isinstance(model_list_json, dict):
|
||||||
model_list_json = model_list_json["data"]
|
if "data" in model_list_json:
|
||||||
|
model_list_json = model_list_json["data"]
|
||||||
|
elif "models" in model_list_json:
|
||||||
|
model_list_json = model_list_json["models"]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid response from API - expected dict with 'data' or "
|
||||||
|
f"'models' field, got {type(model_list_json)}"
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(model_list_json, list):
|
if not isinstance(model_list_json, list):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -27,11 +35,18 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
|
|||||||
for item in model_list_json:
|
for item in model_list_json:
|
||||||
if isinstance(item, str):
|
if isinstance(item, str):
|
||||||
model_names.append(item)
|
model_names.append(item)
|
||||||
elif isinstance(item, dict) and "model_name" in item:
|
elif isinstance(item, dict):
|
||||||
model_names.append(item["model_name"])
|
if "model_name" in item:
|
||||||
|
model_names.append(item["model_name"])
|
||||||
|
elif "id" in item:
|
||||||
|
model_names.append(item["id"])
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid item in model list - expected dict with model_name or id, got {type(item)}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
|
f"Invalid item in model list - expected string or dict, got {type(item)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_names
|
return model_names
|
||||||
|
@@ -0,0 +1,92 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from onyx.background.celery.tasks.llm_model_update.tasks import (
|
||||||
|
_process_model_list_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_data,expected_result,expected_error,error_match",
|
||||||
|
[
|
||||||
|
# Success cases
|
||||||
|
(
|
||||||
|
["gpt-4", "gpt-3.5-turbo", "claude-2"],
|
||||||
|
["gpt-4", "gpt-3.5-turbo", "claude-2"],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
[
|
||||||
|
{"model_name": "gpt-4", "other_field": "value"},
|
||||||
|
{"model_name": "gpt-3.5-turbo", "other_field": "value"},
|
||||||
|
],
|
||||||
|
["gpt-4", "gpt-3.5-turbo"],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
[
|
||||||
|
{"id": "gpt-4", "other_field": "value"},
|
||||||
|
{"id": "gpt-3.5-turbo", "other_field": "value"},
|
||||||
|
],
|
||||||
|
["gpt-4", "gpt-3.5-turbo"],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"data": ["gpt-4", "gpt-3.5-turbo"]},
|
||||||
|
["gpt-4", "gpt-3.5-turbo"],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"models": ["gpt-4", "gpt-3.5-turbo"]},
|
||||||
|
["gpt-4", "gpt-3.5-turbo"],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"models": [{"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}]},
|
||||||
|
["gpt-4", "gpt-3.5-turbo"],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
# Error cases
|
||||||
|
(
|
||||||
|
"not a list",
|
||||||
|
None,
|
||||||
|
ValueError,
|
||||||
|
"Invalid response from API - expected list",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"wrong_field": []},
|
||||||
|
None,
|
||||||
|
ValueError,
|
||||||
|
"Invalid response from API - expected dict with 'data' or 'models' field",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
[{"wrong_field": "value"}],
|
||||||
|
None,
|
||||||
|
ValueError,
|
||||||
|
"Invalid item in model list - expected dict with model_name or id",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
[42],
|
||||||
|
None,
|
||||||
|
ValueError,
|
||||||
|
"Invalid item in model list - expected string or dict",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_process_model_list_response(
|
||||||
|
input_data: dict | list,
|
||||||
|
expected_result: list[str] | None,
|
||||||
|
expected_error: type[Exception] | None,
|
||||||
|
error_match: str | None,
|
||||||
|
) -> None:
|
||||||
|
if expected_error:
|
||||||
|
with pytest.raises(expected_error, match=error_match):
|
||||||
|
_process_model_list_response(input_data)
|
||||||
|
else:
|
||||||
|
result = _process_model_list_response(input_data)
|
||||||
|
assert result == expected_result
|
Reference in New Issue
Block a user