improve gpu detection functions and logging in model server

Author: Richard Kuo (Danswer)
Date:   2025-02-07 16:59:02 -08:00
Parent: ae37f01f62
Commit: bc2c56dfb6

5 changed files with 54 additions and 17 deletions

model_server/constants.py

@@ -28,3 +28,9 @@ class EmbeddingModelTextType:
     @staticmethod
     def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
         return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
+
+
+class GPUStatus:
+    CUDA = "cuda"
+    MAC_MPS = "mps"
+    NONE = "none"

model_server/encoders.py

@@ -12,6 +12,7 @@ import voyageai  # type: ignore
 from cohere import AsyncClient as CohereAsyncClient
 from fastapi import APIRouter
 from fastapi import HTTPException
+from fastapi import Request
 from google.oauth2 import service_account  # type: ignore
 from litellm import aembedding
 from litellm.exceptions import RateLimitError
@@ -320,6 +321,7 @@ async def embed_text(
     prefix: str | None,
     api_url: str | None,
     api_version: str | None,
+    gpu_type: str = "UNKNOWN",
 ) -> list[Embedding]:
     if not all(texts):
         logger.error("Empty strings provided for embedding")
@@ -373,8 +375,11 @@ async def embed_text(
         elapsed = time.monotonic() - start
         logger.info(
-            f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
-            f"with provider {provider_type} in {elapsed:.2f}"
+            f"event=embedding_provider "
+            f"texts={len(texts)} "
+            f"chars={total_chars} "
+            f"provider={provider_type} "
+            f"elapsed={elapsed:.2f}"
         )
     elif model_name is not None:
         logger.info(
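The prose message becomes a logfmt-style key=value line, which is easier to grep and to parse in log pipelines. With sample values (and the provider rendered per the enum's string form — both assumptions), the provider branch would emit something like:

    event=embedding_provider texts=32 chars=10240 provider=openai elapsed=0.84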
@@ -403,6 +408,14 @@ async def embed_text(
             f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
             f"with local model {model_name} in {elapsed:.2f}"
         )
+        logger.info(
+            f"event=embedding_model "
+            f"texts={len(texts)} "
+            f"chars={total_chars} "
+            f"model={provider_type} "
+            f"gpu={gpu_type} "
+            f"elapsed={elapsed:.2f}"
+        )
     else:
         logger.error("Neither model name nor provider specified for embedding")
         raise ValueError(
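The local-model branch gains a parallel structured line carrying the new gpu field. Note that this branch is only reached when provider_type is None, so the model= key as written interpolates None (model_name holds the actual local model id); with sample values the line would read:

    event=embedding_model texts=32 chars=10240 model=None gpu=cuda elapsed=0.42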
@@ -455,8 +468,15 @@ async def litellm_rerank(
 @router.post("/bi-encoder-embed")
-async def process_embed_request(
+async def route_bi_encoder_embed(
+    request: Request,
     embed_request: EmbedRequest,
+) -> EmbedResponse:
+    return await process_embed_request(embed_request, request.app.state.gpu_type)
+
+
+async def process_embed_request(
+    embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
 ) -> EmbedResponse:
     if not embed_request.texts:
         raise HTTPException(status_code=400, detail="No texts to be embedded")
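The endpoint is split into a thin route that pulls gpu_type off application state and a plain coroutine that accepts it as an argument, so the embedding logic stays callable without a live Request. A sketch of a direct call (request construction elided, since EmbedRequest's full field list is not shown in this diff):

    # e.g. in a test, bypassing the HTTP layer entirely
    async def _exercise(embed_request: EmbedRequest) -> None:
        response = await process_embed_request(embed_request, gpu_type="cuda")
        print(len(response.embeddings))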
@@ -484,6 +504,7 @@ async def process_embed_request(
             api_url=embed_request.api_url,
             api_version=embed_request.api_version,
             prefix=prefix,
+            gpu_type=gpu_type,
         )
         return EmbedResponse(embeddings=embeddings)
     except RateLimitError as e:

model_server/main.py

@@ -16,6 +16,7 @@ from model_server.custom_models import router as custom_models_router
 from model_server.custom_models import warm_up_intent_model
 from model_server.encoders import router as encoders_router
 from model_server.management_endpoints import router as management_router
+from model_server.utils import get_gpu_type
 from onyx import __version__
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import INDEXING_ONLY
@@ -58,12 +59,10 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator:
-    if torch.cuda.is_available():
-        logger.notice("CUDA GPU is available")
-    elif torch.backends.mps.is_available():
-        logger.notice("Mac MPS is available")
-    else:
-        logger.notice("GPU is not available, using CPU")
+    gpu_type = get_gpu_type()
+    logger.notice(f"gpu_type={gpu_type}")
+
+    app.state.gpu_type = gpu_type
 
     if TEMP_HF_CACHE_PATH.is_dir():
         logger.notice("Moving contents of temp_huggingface to huggingface cache.")

model_server/management_endpoints.py

@@ -1,7 +1,9 @@
-import torch
 from fastapi import APIRouter
 from fastapi import Response
 
+from model_server.constants import GPUStatus
+from model_server.utils import get_gpu_type
+
 router = APIRouter(prefix="/api")
@@ -11,10 +13,7 @@ async def healthcheck() -> Response:
 @router.get("/gpu-status")
-async def gpu_status() -> dict[str, bool | str]:
-    if torch.cuda.is_available():
-        return {"gpu_available": True, "type": "cuda"}
-    elif torch.backends.mps.is_available():
-        return {"gpu_available": True, "type": "mps"}
-    else:
-        return {"gpu_available": False, "type": "none"}
+async def route_gpu_status() -> dict[str, bool | str]:
+    gpu_type = get_gpu_type()
+    gpu_available = gpu_type != GPUStatus.NONE
+    return {"gpu_available": gpu_available, "type": gpu_type}

model_server/utils.py

@@ -8,6 +8,9 @@ from typing import Any
 from typing import cast
 from typing import TypeVar
 
+import torch
+
+from model_server.constants import GPUStatus
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -58,3 +61,12 @@ def simple_log_function_time(
         return cast(F, wrapped_sync_func)
 
     return decorator
+
+
+def get_gpu_type() -> str:
+    if torch.cuda.is_available():
+        return GPUStatus.CUDA
+    if torch.backends.mps.is_available():
+        return GPUStatus.MAC_MPS
+
+    return GPUStatus.NONE
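get_gpu_type checks CUDA before MPS, so a machine exposing both would report cuda; callers pair it with GPUStatus.NONE for availability checks. A small usage sketch (illustrative only):

    from model_server.constants import GPUStatus
    from model_server.utils import get_gpu_type

    gpu_type = get_gpu_type()  # one of "cuda", "mps", "none"
    if gpu_type != GPUStatus.NONE:
        print(f"accelerator available: {gpu_type}")
    else:
        print("no accelerator detected, using CPU")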