From d123713c006f4a24cc931a3b6bdc2b9980c03e0e Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Fri, 21 Mar 2025 11:11:00 -0700 Subject: [PATCH] Fix GPU status request in sync flow (#4318) * Fix GPU status request in sync flow * tweak * Fix test * Fix more tests --- backend/onyx/chat/answer.py | 6 ++++-- backend/onyx/setup.py | 2 +- backend/onyx/utils/gpu_utils.py | 16 ++++++++++++++-- backend/tests/unit/onyx/chat/test_answer.py | 4 ++-- backend/tests/unit/onyx/chat/test_skip_gen_ai.py | 2 +- 5 files changed, 22 insertions(+), 8 deletions(-) diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py index eb9b2130ddc..0bf937b6c2c 100644 --- a/backend/onyx/chat/answer.py +++ b/backend/onyx/chat/answer.py @@ -30,7 +30,7 @@ from onyx.tools.tool import Tool from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.utils import explicit_tool_calling_supported -from onyx.utils.gpu_utils import gpu_status_request +from onyx.utils.gpu_utils import fast_gpu_status_request from onyx.utils.logger import setup_logger logger = setup_logger() @@ -88,7 +88,9 @@ class Answer: rerank_settings is not None and rerank_settings.rerank_provider_type is not None ) - allow_agent_reranking = gpu_status_request() or using_cloud_reranking + allow_agent_reranking = ( + fast_gpu_status_request(indexing=False) or using_cloud_reranking + ) # TODO: this is a hack to force the query to be used for the search tool # this should be removed once we fully unify graph inputs (i.e. diff --git a/backend/onyx/setup.py b/backend/onyx/setup.py index 750b35d8d14..b1d2a4c04d6 100644 --- a/backend/onyx/setup.py +++ b/backend/onyx/setup.py @@ -324,7 +324,7 @@ def update_default_multipass_indexing(db_session: Session) -> None: logger.info( "No existing docs or connectors found. Checking GPU availability for multipass indexing." ) - gpu_available = gpu_status_request() + gpu_available = gpu_status_request(indexing=True) logger.info(f"GPU available: {gpu_available}") current_settings = get_current_search_settings(db_session) diff --git a/backend/onyx/utils/gpu_utils.py b/backend/onyx/utils/gpu_utils.py index 75acc0232b9..c348e40b136 100644 --- a/backend/onyx/utils/gpu_utils.py +++ b/backend/onyx/utils/gpu_utils.py @@ -1,3 +1,5 @@ +from functools import lru_cache + import requests from retry import retry @@ -10,8 +12,7 @@ from shared_configs.configs import MODEL_SERVER_PORT logger = setup_logger() -@retry(tries=5, delay=5) -def gpu_status_request(indexing: bool = True) -> bool: +def _get_gpu_status_from_model_server(indexing: bool) -> bool: if indexing: model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}" else: @@ -28,3 +29,14 @@ def gpu_status_request(indexing: bool = True) -> bool: except requests.RequestException as e: logger.error(f"Error: Unable to fetch GPU status. Error: {str(e)}") raise # Re-raise exception to trigger a retry + + +@retry(tries=5, delay=5) +def gpu_status_request(indexing: bool) -> bool: + return _get_gpu_status_from_model_server(indexing) + + +@lru_cache(maxsize=1) +def fast_gpu_status_request(indexing: bool) -> bool: + """For use in sync flows, where we don't want to retry / we want to cache this.""" + return gpu_status_request(indexing=indexing) diff --git a/backend/tests/unit/onyx/chat/test_answer.py b/backend/tests/unit/onyx/chat/test_answer.py index 8e2a5f448b7..34f46fff9da 100644 --- a/backend/tests/unit/onyx/chat/test_answer.py +++ b/backend/tests/unit/onyx/chat/test_answer.py @@ -50,7 +50,7 @@ def answer_instance( mocker: MockerFixture, ) -> Answer: mocker.patch( - "onyx.chat.answer.gpu_status_request", + "onyx.chat.answer.fast_gpu_status_request", return_value=True, ) return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config) @@ -400,7 +400,7 @@ def test_no_slow_reranking( mocker: MockerFixture, ) -> None: mocker.patch( - "onyx.chat.answer.gpu_status_request", + "onyx.chat.answer.fast_gpu_status_request", return_value=gpu_enabled, ) rerank_settings = ( diff --git a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py index 146a08f608a..c1c17e36254 100644 --- a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py +++ b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py @@ -39,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag( mocker: MockerFixture, ) -> None: mocker.patch( - "onyx.chat.answer.gpu_status_request", + "onyx.chat.answer.fast_gpu_status_request", return_value=True, ) question = config["question"]