From 5d6a18f358381735c998ea5f51304993acba04fb Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Fri, 24 Jan 2025 10:25:19 -0800 Subject: [PATCH] Add support for more /models/list formats (#3739) --- .../celery/tasks/llm_model_update/tasks.py | 25 ++++- .../llm_model_update/test_llm_model_update.py | 92 +++++++++++++++++++ 2 files changed, 112 insertions(+), 5 deletions(-) create mode 100644 backend/tests/unit/onyx/celery/llm_model_update/test_llm_model_update.py diff --git a/backend/onyx/background/celery/tasks/llm_model_update/tasks.py b/backend/onyx/background/celery/tasks/llm_model_update/tasks.py index 4d58b100a2b1..7bd597661c46 100644 --- a/backend/onyx/background/celery/tasks/llm_model_update/tasks.py +++ b/backend/onyx/background/celery/tasks/llm_model_update/tasks.py @@ -14,8 +14,16 @@ from onyx.db.models import LLMProvider def _process_model_list_response(model_list_json: Any) -> list[str]: # Handle case where response is wrapped in a "data" field - if isinstance(model_list_json, dict) and "data" in model_list_json: - model_list_json = model_list_json["data"] + if isinstance(model_list_json, dict): + if "data" in model_list_json: + model_list_json = model_list_json["data"] + elif "models" in model_list_json: + model_list_json = model_list_json["models"] + else: + raise ValueError( + "Invalid response from API - expected dict with 'data' or " + f"'models' field, got {type(model_list_json)}" + ) if not isinstance(model_list_json, list): raise ValueError( @@ -27,11 +35,18 @@ def _process_model_list_response(model_list_json: Any) -> list[str]: for item in model_list_json: if isinstance(item, str): model_names.append(item) - elif isinstance(item, dict) and "model_name" in item: - model_names.append(item["model_name"]) + elif isinstance(item, dict): + if "model_name" in item: + model_names.append(item["model_name"]) + elif "id" in item: + model_names.append(item["id"]) + else: + raise ValueError( + f"Invalid item in model list - expected dict with model_name or id, got {type(item)}" + ) else: raise ValueError( - f"Invalid item in model list - expected string or dict with model_name, got {type(item)}" + f"Invalid item in model list - expected string or dict, got {type(item)}" ) return model_names diff --git a/backend/tests/unit/onyx/celery/llm_model_update/test_llm_model_update.py b/backend/tests/unit/onyx/celery/llm_model_update/test_llm_model_update.py new file mode 100644 index 000000000000..ed5c0f0d9fa1 --- /dev/null +++ b/backend/tests/unit/onyx/celery/llm_model_update/test_llm_model_update.py @@ -0,0 +1,92 @@ +import pytest + +from onyx.background.celery.tasks.llm_model_update.tasks import ( + _process_model_list_response, +) + + +@pytest.mark.parametrize( + "input_data,expected_result,expected_error,error_match", + [ + # Success cases + ( + ["gpt-4", "gpt-3.5-turbo", "claude-2"], + ["gpt-4", "gpt-3.5-turbo", "claude-2"], + None, + None, + ), + ( + [ + {"model_name": "gpt-4", "other_field": "value"}, + {"model_name": "gpt-3.5-turbo", "other_field": "value"}, + ], + ["gpt-4", "gpt-3.5-turbo"], + None, + None, + ), + ( + [ + {"id": "gpt-4", "other_field": "value"}, + {"id": "gpt-3.5-turbo", "other_field": "value"}, + ], + ["gpt-4", "gpt-3.5-turbo"], + None, + None, + ), + ( + {"data": ["gpt-4", "gpt-3.5-turbo"]}, + ["gpt-4", "gpt-3.5-turbo"], + None, + None, + ), + ( + {"models": ["gpt-4", "gpt-3.5-turbo"]}, + ["gpt-4", "gpt-3.5-turbo"], + None, + None, + ), + ( + {"models": [{"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}]}, + ["gpt-4", "gpt-3.5-turbo"], + None, + None, + ), + # Error cases + ( + "not a list", + None, + ValueError, + "Invalid response from API - expected list", + ), + ( + {"wrong_field": []}, + None, + ValueError, + "Invalid response from API - expected dict with 'data' or 'models' field", + ), + ( + [{"wrong_field": "value"}], + None, + ValueError, + "Invalid item in model list - expected dict with model_name or id", + ), + ( + [42], + None, + ValueError, + "Invalid item in model list - expected string or dict", + ), + ], +) +def test_process_model_list_response( + input_data: dict | list, + expected_result: list[str] | None, + expected_error: type[Exception] | None, + error_match: str | None, +) -> None: + if expected_error: + with pytest.raises(expected_error, match=error_match): + _process_model_list_response(input_data) + else: + result = _process_model_list_response(input_data) + assert result == expected_result