Fix GPU status request in sync flow (#4318)

* Fix GPU status request in sync flow

* tweak

* Fix test

* Fix more tests
This commit is contained in:
Chris Weaver 2025-03-21 11:11:00 -07:00 committed by GitHub
parent 775c847f82
commit d123713c00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 22 additions and 8 deletions

View File

@ -30,7 +30,7 @@ from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import gpu_status_request from onyx.utils.gpu_utils import fast_gpu_status_request
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
@ -88,7 +88,9 @@ class Answer:
rerank_settings is not None rerank_settings is not None
and rerank_settings.rerank_provider_type is not None and rerank_settings.rerank_provider_type is not None
) )
allow_agent_reranking = gpu_status_request() or using_cloud_reranking allow_agent_reranking = (
fast_gpu_status_request(indexing=False) or using_cloud_reranking
)
# TODO: this is a hack to force the query to be used for the search tool # TODO: this is a hack to force the query to be used for the search tool
# this should be removed once we fully unify graph inputs (i.e. # this should be removed once we fully unify graph inputs (i.e.

View File

@ -324,7 +324,7 @@ def update_default_multipass_indexing(db_session: Session) -> None:
logger.info( logger.info(
"No existing docs or connectors found. Checking GPU availability for multipass indexing." "No existing docs or connectors found. Checking GPU availability for multipass indexing."
) )
gpu_available = gpu_status_request() gpu_available = gpu_status_request(indexing=True)
logger.info(f"GPU available: {gpu_available}") logger.info(f"GPU available: {gpu_available}")
current_settings = get_current_search_settings(db_session) current_settings = get_current_search_settings(db_session)

View File

@ -1,3 +1,5 @@
from functools import lru_cache
import requests import requests
from retry import retry from retry import retry
@ -10,8 +12,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger() logger = setup_logger()
@retry(tries=5, delay=5) def _get_gpu_status_from_model_server(indexing: bool) -> bool:
def gpu_status_request(indexing: bool = True) -> bool:
if indexing: if indexing:
model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}" model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}"
else: else:
@ -28,3 +29,14 @@ def gpu_status_request(indexing: bool = True) -> bool:
except requests.RequestException as e: except requests.RequestException as e:
logger.error(f"Error: Unable to fetch GPU status. Error: {str(e)}") logger.error(f"Error: Unable to fetch GPU status. Error: {str(e)}")
raise # Re-raise exception to trigger a retry raise # Re-raise exception to trigger a retry
@retry(tries=5, delay=5)
def gpu_status_request(indexing: bool) -> bool:
    """Check whether a GPU is available, retrying on transient failures.

    Thin retrying wrapper around the single-attempt model-server request:
    up to 5 tries with a 5 second delay between attempts when the
    underlying call raises (e.g. on a request error).
    """
    gpu_available = _get_gpu_status_from_model_server(indexing=indexing)
    return gpu_available
@lru_cache(maxsize=2)
def fast_gpu_status_request(indexing: bool) -> bool:
    """Cached GPU status check for latency-sensitive sync flows.

    Unlike ``gpu_status_request``, the result is memoized for the lifetime
    of the process, so only the first call per ``indexing`` value pays the
    cost of the (retried) network round trip to the model server.

    NOTE: ``indexing`` has two possible values (True/False), i.e. two
    distinct cache keys, so the cache needs two slots. With ``maxsize=1``
    callers alternating between the indexing and non-indexing model
    servers would evict each other's entry and re-issue the request on
    every call, defeating the cache.
    """
    return gpu_status_request(indexing=indexing)

View File

@ -50,7 +50,7 @@ def answer_instance(
mocker: MockerFixture, mocker: MockerFixture,
) -> Answer: ) -> Answer:
mocker.patch( mocker.patch(
"onyx.chat.answer.gpu_status_request", "onyx.chat.answer.fast_gpu_status_request",
return_value=True, return_value=True,
) )
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config) return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
@ -400,7 +400,7 @@ def test_no_slow_reranking(
mocker: MockerFixture, mocker: MockerFixture,
) -> None: ) -> None:
mocker.patch( mocker.patch(
"onyx.chat.answer.gpu_status_request", "onyx.chat.answer.fast_gpu_status_request",
return_value=gpu_enabled, return_value=gpu_enabled,
) )
rerank_settings = ( rerank_settings = (

View File

@ -39,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag(
mocker: MockerFixture, mocker: MockerFixture,
) -> None: ) -> None:
mocker.patch( mocker.patch(
"onyx.chat.answer.gpu_status_request", "onyx.chat.answer.fast_gpu_status_request",
return_value=True, return_value=True,
) )
question = config["question"] question = config["question"]