Add support for more /models/list formats (#3739)

This commit is contained in:
Chris Weaver
2025-01-24 10:25:19 -08:00
committed by GitHub
parent 3c37764974
commit 5d6a18f358
2 changed files with 112 additions and 5 deletions

View File

@@ -14,8 +14,16 @@ from onyx.db.models import LLMProvider
def _process_model_list_response(model_list_json: Any) -> list[str]:
# Handle case where response is wrapped in a "data" field
if isinstance(model_list_json, dict) and "data" in model_list_json:
model_list_json = model_list_json["data"]
if isinstance(model_list_json, dict):
if "data" in model_list_json:
model_list_json = model_list_json["data"]
elif "models" in model_list_json:
model_list_json = model_list_json["models"]
else:
raise ValueError(
"Invalid response from API - expected dict with 'data' or "
f"'models' field, got {type(model_list_json)}"
)
if not isinstance(model_list_json, list):
raise ValueError(
@@ -27,11 +35,18 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
for item in model_list_json:
if isinstance(item, str):
model_names.append(item)
elif isinstance(item, dict) and "model_name" in item:
model_names.append(item["model_name"])
elif isinstance(item, dict):
if "model_name" in item:
model_names.append(item["model_name"])
elif "id" in item:
model_names.append(item["id"])
else:
raise ValueError(
f"Invalid item in model list - expected dict with model_name or id, got {type(item)}"
)
else:
raise ValueError(
f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
f"Invalid item in model list - expected string or dict, got {type(item)}"
)
return model_names

View File

@@ -0,0 +1,92 @@
import pytest
from onyx.background.celery.tasks.llm_model_update.tasks import (
_process_model_list_response,
)
@pytest.mark.parametrize(
"input_data,expected_result,expected_error,error_match",
[
# Success cases
(
["gpt-4", "gpt-3.5-turbo", "claude-2"],
["gpt-4", "gpt-3.5-turbo", "claude-2"],
None,
None,
),
(
[
{"model_name": "gpt-4", "other_field": "value"},
{"model_name": "gpt-3.5-turbo", "other_field": "value"},
],
["gpt-4", "gpt-3.5-turbo"],
None,
None,
),
(
[
{"id": "gpt-4", "other_field": "value"},
{"id": "gpt-3.5-turbo", "other_field": "value"},
],
["gpt-4", "gpt-3.5-turbo"],
None,
None,
),
(
{"data": ["gpt-4", "gpt-3.5-turbo"]},
["gpt-4", "gpt-3.5-turbo"],
None,
None,
),
(
{"models": ["gpt-4", "gpt-3.5-turbo"]},
["gpt-4", "gpt-3.5-turbo"],
None,
None,
),
(
{"models": [{"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}]},
["gpt-4", "gpt-3.5-turbo"],
None,
None,
),
# Error cases
(
"not a list",
None,
ValueError,
"Invalid response from API - expected list",
),
(
{"wrong_field": []},
None,
ValueError,
"Invalid response from API - expected dict with 'data' or 'models' field",
),
(
[{"wrong_field": "value"}],
None,
ValueError,
"Invalid item in model list - expected dict with model_name or id",
),
(
[42],
None,
ValueError,
"Invalid item in model list - expected string or dict",
),
],
)
def test_process_model_list_response(
input_data: dict | list,
expected_result: list[str] | None,
expected_error: type[Exception] | None,
error_match: str | None,
) -> None:
if expected_error:
with pytest.raises(expected_error, match=error_match):
_process_model_list_response(input_data)
else:
result = _process_model_list_response(input_data)
assert result == expected_result