import os
import shutil
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path

import sentry_sdk
import torch
import uvicorn
from fastapi import FastAPI
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging  # type:ignore

from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_information_content_model
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
from onyx import __version__
from onyx.utils.logger import setup_logger
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN
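
# Silence HF tokenizers' fork-related parallelism warnings and opt out of
# Hugging Face Hub telemetry before any model code loads.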
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/huggingface"
TEMP_HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/temp_huggingface"

transformer_logging.set_verbosity_error()

logger = setup_logger()


def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -> None:
"""
This moves the files from the temp huggingface cache to the huggingface cache
We have to move each file individually because the directories might
have the same name but not the same contents and we dont want to remove
the files in the existing huggingface cache that don't exist in the temp
huggingface cache.
"""
    for item in source.iterdir():
        target_path = dest / item.relative_to(source)
        if item.is_dir():
            _move_files_recursively(item, target_path, overwrite)
        else:
            target_path.parent.mkdir(parents=True, exist_ok=True)
            if target_path.exists() and not overwrite:
                continue
            shutil.move(str(item), str(target_path))
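
# Illustrative merge semantics (hypothetical paths):
#   temp cache: ~/.cache/temp_huggingface/models--foo/blobs/abc
#   real cache: ~/.cache/huggingface/models--foo/blobs/def
# After _move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH), the real
# cache holds both blobs; a file that already exists is kept unless
# overwrite=True.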


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    gpu_type = get_gpu_type()
    logger.notice(f"Torch GPU Detection: gpu_type={gpu_type}")
    app.state.gpu_type = gpu_type
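
    # A build step may pre-download models into TEMP_HF_CACHE_PATH; merge any
    # such staged files into the live cache before serving. (Assumption based
    # on the merge helper above; the producer of the temp cache isn't shown
    # in this file.)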
    if TEMP_HF_CACHE_PATH.is_dir():
        logger.notice("Moving contents of temp_huggingface to huggingface cache.")
        _move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
        shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
        logger.notice("Moved contents of temp_huggingface to huggingface cache.")
    torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
    logger.notice(f"Torch Threads: {torch.get_num_threads()}")

    if not INDEXING_ONLY:
        logger.notice(
            "The intent model should run on the model server. The information content model should not run here."
        )
        warm_up_intent_model()
    else:
        logger.notice(
            "The information content model should run on the indexing model server. The intent model should not run here."
        )
        warm_up_information_content_model()
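
    # Everything before this yield runs at startup; control returns here at
    # shutdown.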
    yield


def get_model_app() -> FastAPI:
    application = FastAPI(
        title="Onyx Model Server", version=__version__, lifespan=lifespan
    )
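
    # Optional Sentry error and performance monitoring; traces_sample_rate=0.1
    # samples roughly 10% of transactions for tracing.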
    if SENTRY_DSN:
        sentry_sdk.init(
            dsn=SENTRY_DSN,
            integrations=[StarletteIntegration(), FastApiIntegration()],
            traces_sample_rate=0.1,
        )
        logger.info("Sentry initialized")
    else:
        logger.debug("Sentry DSN not provided, skipping Sentry initialization")

    application.include_router(management_router)
    application.include_router(encoders_router)
    application.include_router(custom_models_router)

    return application


app = get_model_app()


if __name__ == "__main__":
    logger.notice(
        f"Starting Onyx Model Server on http://{MODEL_SERVER_ALLOWED_HOST}:{MODEL_SERVER_PORT}/"
    )
    logger.notice(f"Model Server Version: {__version__}")
    uvicorn.run(app, host=MODEL_SERVER_ALLOWED_HOST, port=MODEL_SERVER_PORT)
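
# Run directly (module path assumed from this file's location in the repo):
#   python -m model_server.main
# The server then listens on MODEL_SERVER_ALLOWED_HOST:MODEL_SERVER_PORT as
# configured in shared_configs.configs.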