Fix GPU status request in sync flow (#4318)

* Fix GPU status request in sync flow

* tweak

* Fix test

* Fix more tests
Author: Chris Weaver, 2025-03-21 11:11:00 -07:00 (committed by GitHub)
commit d123713c00 (parent 775c847f82)
5 changed files with 22 additions and 8 deletions

File: onyx/chat/answer.py

@@ -30,7 +30,7 @@ from onyx.tools.tool import Tool
 from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.utils import explicit_tool_calling_supported
-from onyx.utils.gpu_utils import gpu_status_request
+from onyx.utils.gpu_utils import fast_gpu_status_request
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -88,7 +88,9 @@ class Answer:
             rerank_settings is not None
             and rerank_settings.rerank_provider_type is not None
         )
-        allow_agent_reranking = gpu_status_request() or using_cloud_reranking
+        allow_agent_reranking = (
+            fast_gpu_status_request(indexing=False) or using_cloud_reranking
+        )
 
         # TODO: this is a hack to force the query to be used for the search tool
         # this should be removed once we fully unify graph inputs (i.e.
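The gate above allows agent reranking when either a local GPU is available or a cloud rerank provider is configured. A minimal standalone sketch of that logic (the RerankSettings stand-in and the gpu_available flag are assumptions for illustration, not project code):

from dataclasses import dataclass
from typing import Optional


@dataclass
class RerankSettings:  # stand-in for the project's rerank settings object
    rerank_provider_type: Optional[str] = None


def allow_agent_reranking(
    gpu_available: bool, rerank_settings: Optional[RerankSettings]
) -> bool:
    # Mirrors the diff: cloud reranking qualifies even without a GPU.
    using_cloud_reranking = (
        rerank_settings is not None
        and rerank_settings.rerank_provider_type is not None
    )
    return gpu_available or using_cloud_reranking


assert allow_agent_reranking(False, RerankSettings("cohere")) is True
assert allow_agent_reranking(False, RerankSettings()) is False
assert allow_agent_reranking(True, None) is True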

File: (path not shown)

@@ -324,7 +324,7 @@ def update_default_multipass_indexing(db_session: Session) -> None:
     logger.info(
         "No existing docs or connectors found. Checking GPU availability for multipass indexing."
     )
-    gpu_available = gpu_status_request()
+    gpu_available = gpu_status_request(indexing=True)
     logger.info(f"GPU available: {gpu_available}")
 
     current_settings = get_current_search_settings(db_session)
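Note that the refactored gpu_status_request (below) drops the old indexing=True default, so call sites like this one must now be explicit. A tiny sketch of the effect (stubbed body, assumed for illustration):

def gpu_status_request(indexing: bool) -> bool:  # no default after this change
    return True  # stub; the real function probes the model server


gpu_status_request(indexing=True)  # OK: explicit, as in the diff above
# gpu_status_request()  # TypeError: missing 1 required positional argument: 'indexing'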

File: onyx/utils/gpu_utils.py

@@ -1,3 +1,5 @@
+from functools import lru_cache
+
 import requests
 from retry import retry
 
@@ -10,8 +12,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
 logger = setup_logger()
 
 
-@retry(tries=5, delay=5)
-def gpu_status_request(indexing: bool = True) -> bool:
+def _get_gpu_status_from_model_server(indexing: bool) -> bool:
     if indexing:
         model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}"
     else:
@@ -28,3 +29,14 @@ def gpu_status_request(indexing: bool = True) -> bool:
     except requests.RequestException as e:
         logger.error(f"Error: Unable to fetch GPU status. Error: {str(e)}")
         raise  # Re-raise exception to trigger a retry
+
+
+@retry(tries=5, delay=5)
+def gpu_status_request(indexing: bool) -> bool:
+    return _get_gpu_status_from_model_server(indexing)
+
+
+@lru_cache(maxsize=1)
+def fast_gpu_status_request(indexing: bool) -> bool:
+    """For use in sync flows, where we don't want to retry / we want to cache this."""
+    return _get_gpu_status_from_model_server(indexing)
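Taken together, gpu_utils now exposes three layers: the raw model-server probe, a retrying wrapper for background/indexing flows, and a cached non-retrying wrapper for latency-sensitive sync flows. A self-contained sketch of the caching behavior (the HTTP probe is replaced by a counter so the effect is observable; the two function names mirror the diff, everything else is illustrative):

from functools import lru_cache

CALLS = 0


def _get_gpu_status_from_model_server(indexing: bool) -> bool:
    # Stand-in for the HTTP probe above; counts invocations.
    global CALLS
    CALLS += 1
    return True


@lru_cache(maxsize=1)
def fast_gpu_status_request(indexing: bool) -> bool:
    """Cached variant for sync flows (mirrors the diff)."""
    return _get_gpu_status_from_model_server(indexing)


fast_gpu_status_request(indexing=False)
fast_gpu_status_request(indexing=False)
assert CALLS == 1  # the second call is served from the cache, no network hit

Two properties of lru_cache worth noting here: maxsize=1 keeps only the most recently used argument tuple, so if both indexing=True and indexing=False callers ever used the fast variant, alternating calls would evict each other (today only the indexing=False chat path uses it); and exceptions are never cached, so a failed probe is re-attempted on the next call.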

File: (path not shown)

@@ -50,7 +50,7 @@ def answer_instance(
     mocker: MockerFixture,
 ) -> Answer:
     mocker.patch(
-        "onyx.chat.answer.gpu_status_request",
+        "onyx.chat.answer.fast_gpu_status_request",
         return_value=True,
     )
     return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
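The patch target changes in lockstep with the import: because answer.py binds the function into its own namespace via "from onyx.utils.gpu_utils import fast_gpu_status_request", tests must patch the name where it is looked up (onyx.chat.answer.fast_gpu_status_request), not where it is defined. A small standalone illustration of that rule (throwaway modules, assumed purely for the demo):

import types
from unittest import mock

# Simulate `from gpu_utils import status` inside an `answer` module.
gpu_utils = types.ModuleType("gpu_utils")
gpu_utils.status = lambda: "real"
answer = types.ModuleType("answer")
answer.status = gpu_utils.status  # what the `from ... import ...` did

with mock.patch.object(answer, "status", return_value="fake"):
    assert answer.status() == "fake"  # patching the use site takes effect
assert gpu_utils.status() == "real"  # the defining module was untouched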
@@ -400,7 +400,7 @@ def test_no_slow_reranking(
     mocker: MockerFixture,
 ) -> None:
     mocker.patch(
-        "onyx.chat.answer.gpu_status_request",
+        "onyx.chat.answer.fast_gpu_status_request",
         return_value=gpu_enabled,
     )
     rerank_settings = (

File: (path not shown)

@@ -39,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag(
     mocker: MockerFixture,
 ) -> None:
     mocker.patch(
-        "onyx.chat.answer.gpu_status_request",
+        "onyx.chat.answer.fast_gpu_status_request",
         return_value=True,
     )
     question = config["question"]