From bc2c56dfb6431cdc0a6952cf1990f55e783d2a03 Mon Sep 17 00:00:00 2001
From: "Richard Kuo (Danswer)"
Date: Fri, 7 Feb 2025 16:59:02 -0800
Subject: [PATCH] improve gpu detection functions and logging in model server

---
 backend/model_server/constants.py            |  6 +++++
 backend/model_server/encoders.py             | 27 +++++++++++++++++---
 backend/model_server/main.py                 | 11 ++++----
 backend/model_server/management_endpoints.py | 15 +++++------
 backend/model_server/utils.py                | 12 +++++++++
 5 files changed, 54 insertions(+), 17 deletions(-)

diff --git a/backend/model_server/constants.py b/backend/model_server/constants.py
index d6991b402..fac57cb73 100644
--- a/backend/model_server/constants.py
+++ b/backend/model_server/constants.py
@@ -28,3 +28,9 @@ class EmbeddingModelTextType:
     @staticmethod
     def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
         return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
+
+
+class GPUStatus:
+    CUDA = "cuda"
+    MAC_MPS = "mps"
+    NONE = "none"
diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py
index 3ed3857f6..502eeecef 100644
--- a/backend/model_server/encoders.py
+++ b/backend/model_server/encoders.py
@@ -12,6 +12,7 @@ import voyageai  # type: ignore
 from cohere import AsyncClient as CohereAsyncClient
 from fastapi import APIRouter
 from fastapi import HTTPException
+from fastapi import Request
 from google.oauth2 import service_account  # type: ignore
 from litellm import aembedding
 from litellm.exceptions import RateLimitError
@@ -320,6 +321,7 @@ async def embed_text(
     prefix: str | None,
     api_url: str | None,
     api_version: str | None,
+    gpu_type: str = "UNKNOWN",
 ) -> list[Embedding]:
     if not all(texts):
         logger.error("Empty strings provided for embedding")
@@ -373,8 +375,11 @@ async def embed_text(
 
         elapsed = time.monotonic() - start
         logger.info(
-            f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
-            f"with provider {provider_type} in {elapsed:.2f}"
+            f"event=embedding_provider "
+            f"texts={len(texts)} "
+            f"chars={total_chars} "
+            f"provider={provider_type} "
+            f"elapsed={elapsed:.2f}"
         )
     elif model_name is not None:
         logger.info(
@@ -403,6 +408,14 @@ async def embed_text(
             f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
             f"with local model {model_name} in {elapsed:.2f}"
         )
+        logger.info(
+            f"event=embedding_model "
+            f"texts={len(texts)} "
+            f"chars={total_chars} "
+            f"model={model_name} "
+            f"gpu={gpu_type} "
+            f"elapsed={elapsed:.2f}"
+        )
     else:
         logger.error("Neither model name nor provider specified for embedding")
         raise ValueError(
@@ -455,8 +468,15 @@ async def litellm_rerank(
 
 
 @router.post("/bi-encoder-embed")
-async def process_embed_request(
+async def route_bi_encoder_embed(
+    request: Request,
     embed_request: EmbedRequest,
+) -> EmbedResponse:
+    return await process_embed_request(embed_request, request.app.state.gpu_type)
+
+
+async def process_embed_request(
+    embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
 ) -> EmbedResponse:
     if not embed_request.texts:
         raise HTTPException(status_code=400, detail="No texts to be embedded")
@@ -484,6 +504,7 @@ async def process_embed_request(
             api_url=embed_request.api_url,
             api_version=embed_request.api_version,
             prefix=prefix,
+            gpu_type=gpu_type,
         )
         return EmbedResponse(embeddings=embeddings)
     except RateLimitError as e:
diff --git a/backend/model_server/main.py b/backend/model_server/main.py
index 9a9cda7fd..2031d69ea 100644
--- a/backend/model_server/main.py
+++ b/backend/model_server/main.py
@@ -16,6 +16,7 @@ from model_server.custom_models import router as custom_models_router
 from model_server.custom_models import warm_up_intent_model
 from model_server.encoders import router as encoders_router
 from model_server.management_endpoints import router as management_router
+from model_server.utils import get_gpu_type
 from onyx import __version__
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import INDEXING_ONLY
@@ -58,12 +59,10 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
 
 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator:
-    if torch.cuda.is_available():
-        logger.notice("CUDA GPU is available")
-    elif torch.backends.mps.is_available():
-        logger.notice("Mac MPS is available")
-    else:
-        logger.notice("GPU is not available, using CPU")
+    gpu_type = get_gpu_type()
+    logger.notice(f"gpu_type={gpu_type}")
+
+    app.state.gpu_type = gpu_type
 
     if TEMP_HF_CACHE_PATH.is_dir():
         logger.notice("Moving contents of temp_huggingface to huggingface cache.")
diff --git a/backend/model_server/management_endpoints.py b/backend/model_server/management_endpoints.py
index 4c6387e07..a722911ae 100644
--- a/backend/model_server/management_endpoints.py
+++ b/backend/model_server/management_endpoints.py
@@ -1,7 +1,9 @@
-import torch
 from fastapi import APIRouter
 from fastapi import Response
 
+from model_server.constants import GPUStatus
+from model_server.utils import get_gpu_type
+
 router = APIRouter(prefix="/api")
 
 
@@ -11,10 +13,7 @@ async def healthcheck() -> Response:
 
 
 @router.get("/gpu-status")
-async def gpu_status() -> dict[str, bool | str]:
-    if torch.cuda.is_available():
-        return {"gpu_available": True, "type": "cuda"}
-    elif torch.backends.mps.is_available():
-        return {"gpu_available": True, "type": "mps"}
-    else:
-        return {"gpu_available": False, "type": "none"}
+async def route_gpu_status() -> dict[str, bool | str]:
+    gpu_type = get_gpu_type()
+    gpu_available = gpu_type != GPUStatus.NONE
+    return {"gpu_available": gpu_available, "type": gpu_type}
diff --git a/backend/model_server/utils.py b/backend/model_server/utils.py
index 3580da684..b53431fda 100644
--- a/backend/model_server/utils.py
+++ b/backend/model_server/utils.py
@@ -8,6 +8,9 @@ from typing import Any
 from typing import cast
 from typing import TypeVar
 
+import torch
+
+from model_server.constants import GPUStatus
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -58,3 +61,12 @@ def simple_log_function_time(
         return cast(F, wrapped_sync_func)
 
     return decorator
+
+
+def get_gpu_type() -> str:
+    if torch.cuda.is_available():
+        return GPUStatus.CUDA
+    if torch.backends.mps.is_available():
+        return GPUStatus.MAC_MPS
+
+    return GPUStatus.NONE
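
A minimal sketch (not part of the patch itself) of how the new get_gpu_type() helper could be exercised in a unit test, assuming the backend package is importable (e.g. pytest run from backend/) and that the torch availability checks can be monkeypatched:

# Hypothetical test sketch -- illustrative only, not included in the patch.
# Assumes torch is installed and model_server is on PYTHONPATH.
import pytest
import torch

from model_server.constants import GPUStatus
from model_server.utils import get_gpu_type


@pytest.mark.parametrize(
    "cuda_available, mps_available, expected",
    [
        (True, False, GPUStatus.CUDA),     # CUDA reported when present
        (False, True, GPUStatus.MAC_MPS),  # MPS only checked when CUDA is absent
        (False, False, GPUStatus.NONE),    # CPU-only fallback
    ],
)
def test_get_gpu_type(
    monkeypatch: pytest.MonkeyPatch,
    cuda_available: bool,
    mps_available: bool,
    expected: str,
) -> None:
    # Force each detection branch instead of depending on the host hardware.
    monkeypatch.setattr(torch.cuda, "is_available", lambda: cuda_available)
    monkeypatch.setattr(torch.backends.mps, "is_available", lambda: mps_available)

    assert get_gpu_type() == expected

Centralizing the torch checks in this single helper (rather than inlining them in main.py and management_endpoints.py, as before the patch) is what makes this kind of branch-by-branch test straightforward.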