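"""Entry point for the Danswer model server: a FastAPI application serving
the encoder (embedding) and custom intent-model endpoints."""
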
import asyncio
import os
import shutil
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path

import torch
import uvicorn
from fastapi import FastAPI
from transformers import logging as transformer_logging  # type:ignore

from danswer import __version__
from danswer.utils.logger import setup_logger
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
from shared_configs.configs import MODEL_SERVER_PORT

# Keep tokenizer parallelism off to avoid fork-related deadlocks, and opt
# out of HuggingFace Hub telemetry
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

# Silence transformers warnings (e.g. uninitialized-weight notices)
transformer_logging.set_verbosity_error()

logger = setup_logger()
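

# Note: temp_huggingface is presumably staged at image build time so that
# model weights ship inside the container (an assumption; the Dockerfile is
# not shown here).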
async def manage_huggingface_cache() -> None:
    """Promote any model files staged under temp_huggingface into the real
    HuggingFace cache directory, then delete the staging directory."""
    temp_hf_cache = Path("/root/.cache/temp_huggingface")
    hf_cache = Path("/root/.cache/huggingface")
    if temp_hf_cache.is_dir() and any(temp_hf_cache.iterdir()):
        hf_cache.mkdir(parents=True, exist_ok=True)
        for item in temp_hf_cache.iterdir():
            if item.is_dir():
                # Offload the blocking copy to a worker thread so the event
                # loop stays responsive during startup
                await asyncio.to_thread(
                    shutil.copytree, item, hf_cache / item.name, dirs_exist_ok=True
                )
            else:
                await asyncio.to_thread(shutil.copy2, item, hf_cache)
        await asyncio.to_thread(shutil.rmtree, temp_hf_cache)
        logger.info("Copied contents of temp_huggingface and deleted the directory.")
    else:
        logger.info("Source directory is empty or does not exist. Skipping copy.")


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    if torch.cuda.is_available():
        logger.info("GPU is available")
    else:
        logger.info("GPU is not available")

    await manage_huggingface_cache()

    # Never run inference with fewer than MIN_THREADS_ML_MODELS torch threads
    torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
    logger.info(f"Torch Threads: {torch.get_num_threads()}")

    if not INDEXING_ONLY:
        # Load the intent model eagerly so the first request is not slow
        warm_up_intent_model()
    else:
        logger.info("This model server should only run document indexing.")

    yield


def get_model_app() -> FastAPI:
    application = FastAPI(
        title="Danswer Model Server", version=__version__, lifespan=lifespan
    )

    application.include_router(management_router)
    application.include_router(encoders_router)
    application.include_router(custom_models_router)

    return application


app = get_model_app()


if __name__ == "__main__":
    logger.info(
        f"Starting Danswer Model Server on http://{MODEL_SERVER_ALLOWED_HOST}:{MODEL_SERVER_PORT}/"
    )
    logger.info(f"Model Server Version: {__version__}")
    uvicorn.run(app, host=MODEL_SERVER_ALLOWED_HOST, port=MODEL_SERVER_PORT)
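

# A minimal sketch of how this server might be launched; the module path
# model_server.main is an assumption based on the imports above:
#
#   python -m model_server.main
#
# or via the uvicorn CLI, substituting the configured host and port:
#
#   uvicorn model_server.main:app --host 0.0.0.0 --port 9000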