improve gpu detection functions and logging in model server
@@ -28,3 +28,9 @@ class EmbeddingModelTextType:
     @staticmethod
     def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
         return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
+
+
+class GPUStatus:
+    CUDA = "cuda"
+    MAC_MPS = "mps"
+    NONE = "none"
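The hunk above adds GPUStatus as a plain namespace of string constants; later hunks import it from model_server.constants, so that is presumably the file being edited. A minimal standalone sketch of how such constants are consumed (the describe helper is illustrative, not part of the commit):

class GPUStatus:
    CUDA = "cuda"
    MAC_MPS = "mps"
    NONE = "none"


def describe(gpu_type: str) -> str:
    # Plain equality checks work because the attributes are ordinary str values.
    if gpu_type == GPUStatus.NONE:
        return "no accelerator detected"
    return f"accelerator detected: {gpu_type}"


print(describe(GPUStatus.MAC_MPS))  # accelerator detected: mps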
@@ -12,6 +12,7 @@ import voyageai # type: ignore
 from cohere import AsyncClient as CohereAsyncClient
 from fastapi import APIRouter
 from fastapi import HTTPException
+from fastapi import Request
 from google.oauth2 import service_account # type: ignore
 from litellm import aembedding
 from litellm.exceptions import RateLimitError
@@ -320,6 +321,7 @@ async def embed_text(
     prefix: str | None,
     api_url: str | None,
     api_version: str | None,
+    gpu_type: str = "UNKNOWN",
 ) -> list[Embedding]:
     if not all(texts):
         logger.error("Empty strings provided for embedding")
@@ -373,8 +375,11 @@ async def embed_text(
 
         elapsed = time.monotonic() - start
         logger.info(
-            f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
-            f"with provider {provider_type} in {elapsed:.2f}"
+            f"event=embedding_provider "
+            f"texts={len(texts)} "
+            f"chars={total_chars} "
+            f"provider={provider_type} "
+            f"elapsed={elapsed:.2f}"
         )
     elif model_name is not None:
         logger.info(
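The replacement log message is built from adjacent f-string literals, which Python concatenates into one space-separated key=value line. A small sketch of the resulting string, with made-up values:

texts = ["a", "b", "c"]
total_chars = 42
provider_type = "openai"
elapsed = 0.57
message = (
    f"event=embedding_provider "
    f"texts={len(texts)} "
    f"chars={total_chars} "
    f"provider={provider_type} "
    f"elapsed={elapsed:.2f}"
)
print(message)  # event=embedding_provider texts=3 chars=42 provider=openai elapsed=0.57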
@@ -403,6 +408,14 @@ async def embed_text(
             f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
             f"with local model {model_name} in {elapsed:.2f}"
         )
+        logger.info(
+            f"event=embedding_model "
+            f"texts={len(texts)} "
+            f"chars={total_chars} "
+            f"model={provider_type} "
+            f"gpu={gpu_type} "
+            f"elapsed={elapsed:.2f}"
+        )
     else:
         logger.error("Neither model name nor provider specified for embedding")
         raise ValueError(
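A side benefit of key=value log lines such as the new event=embedding_model message is that they split cleanly back into fields. A hypothetical parser, not part of the commit, run against a made-up sample line:

def parse_kv(message: str) -> dict[str, str]:
    # Keep only whitespace-separated tokens that contain an '='.
    fields: dict[str, str] = {}
    for token in message.split():
        key, sep, value = token.partition("=")
        if sep:
            fields[key] = value
    return fields


sample = "event=embedding_model texts=3 chars=42 model=local-embedder gpu=mps elapsed=0.12"
print(parse_kv(sample)["gpu"])  # mps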
@@ -455,8 +468,15 @@ async def litellm_rerank(
 
 
 @router.post("/bi-encoder-embed")
-async def process_embed_request(
+async def route_bi_encoder_embed(
+    request: Request,
     embed_request: EmbedRequest,
+) -> EmbedResponse:
+    return await process_embed_request(embed_request, request.app.state.gpu_type)
+
+
+async def process_embed_request(
+    embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
 ) -> EmbedResponse:
     if not embed_request.texts:
         raise HTTPException(status_code=400, detail="No texts to be embedded")
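The /bi-encoder-embed route becomes a thin wrapper that reads gpu_type from app.state (populated at startup, see the lifespan hunk further down) and forwards it to process_embed_request, which keeps the handler directly callable in tests. A minimal sketch of the same route/handler split; the endpoint name, state value, and helper below are illustrative:

from fastapi import FastAPI, Request

app = FastAPI()
app.state.gpu_type = "cuda"  # normally set once during startup


@app.post("/echo-gpu")
async def route_echo_gpu(request: Request) -> dict[str, str]:
    # Thin route: pull shared state off the app, delegate to a plain async function.
    return await process_echo_gpu(request.app.state.gpu_type)


async def process_echo_gpu(gpu_type: str = "UNKNOWN") -> dict[str, str]:
    return {"gpu": gpu_type}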
@@ -484,6 +504,7 @@ async def process_embed_request(
             api_url=embed_request.api_url,
             api_version=embed_request.api_version,
             prefix=prefix,
+            gpu_type=gpu_type,
         )
         return EmbedResponse(embeddings=embeddings)
     except RateLimitError as e:
@@ -16,6 +16,7 @@ from model_server.custom_models import router as custom_models_router
 from model_server.custom_models import warm_up_intent_model
 from model_server.encoders import router as encoders_router
 from model_server.management_endpoints import router as management_router
+from model_server.utils import get_gpu_type
 from onyx import __version__
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import INDEXING_ONLY
@@ -58,12 +59,10 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
 
 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator:
-    if torch.cuda.is_available():
-        logger.notice("CUDA GPU is available")
-    elif torch.backends.mps.is_available():
-        logger.notice("Mac MPS is available")
-    else:
-        logger.notice("GPU is not available, using CPU")
+    gpu_type = get_gpu_type()
+    logger.notice(f"gpu_type={gpu_type}")
+
+    app.state.gpu_type = gpu_type
 
     if TEMP_HF_CACHE_PATH.is_dir():
         logger.notice("Moving contents of temp_huggingface to huggingface cache.")
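lifespan now resolves the GPU type once at startup, logs it, and stashes it on app.state for request handlers. A stripped-down sketch of that pattern, with detection replaced by a hard-coded value:

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    app.state.gpu_type = "none"  # stand-in for get_gpu_type()
    yield  # the application serves requests while suspended here


app = FastAPI(lifespan=lifespan)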
@@ -1,7 +1,9 @@
-import torch
 from fastapi import APIRouter
 from fastapi import Response
 
+from model_server.constants import GPUStatus
+from model_server.utils import get_gpu_type
+
 router = APIRouter(prefix="/api")
 
 
@@ -11,10 +13,7 @@ async def healthcheck() -> Response:
 
 
 @router.get("/gpu-status")
-async def gpu_status() -> dict[str, bool | str]:
-    if torch.cuda.is_available():
-        return {"gpu_available": True, "type": "cuda"}
-    elif torch.backends.mps.is_available():
-        return {"gpu_available": True, "type": "mps"}
-    else:
-        return {"gpu_available": False, "type": "none"}
+async def route_gpu_status() -> dict[str, bool | str]:
+    gpu_type = get_gpu_type()
+    gpu_available = gpu_type != GPUStatus.NONE
+    return {"gpu_available": gpu_available, "type": gpu_type}
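/gpu-status now reports whatever backend get_gpu_type() detected instead of branching per backend. A sketch of exercising an endpoint of this shape with FastAPI's TestClient; the app wiring and the stubbed detection below are illustrative:

from fastapi import APIRouter, FastAPI
from fastapi.testclient import TestClient

router = APIRouter(prefix="/api")


@router.get("/gpu-status")
async def route_gpu_status() -> dict[str, bool | str]:
    gpu_type = "none"  # stand-in for get_gpu_type()
    return {"gpu_available": gpu_type != "none", "type": gpu_type}


app = FastAPI()
app.include_router(router)

print(TestClient(app).get("/api/gpu-status").json())  # {'gpu_available': False, 'type': 'none'}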
@@ -8,6 +8,9 @@ from typing import Any
 from typing import cast
 from typing import TypeVar
 
+import torch
+
+from model_server.constants import GPUStatus
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -58,3 +61,12 @@ def simple_log_function_time(
         return cast(F, wrapped_sync_func)
 
     return decorator
+
+
+def get_gpu_type() -> str:
+    if torch.cuda.is_available():
+        return GPUStatus.CUDA
+    if torch.backends.mps.is_available():
+        return GPUStatus.MAC_MPS
+
+    return GPUStatus.NONE
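get_gpu_type() returns "cuda", "mps", or "none". The first two also happen to be valid torch device strings, so a caller could reuse the value to pick a device; that usage is an assumption, the commit itself only logs and reports the value:

import torch

gpu_type = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "none"
)
device = torch.device(gpu_type if gpu_type != "none" else "cpu")
print(torch.zeros(2, 2, device=device).device)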