mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-07 02:10:30 +02:00
73 lines
2.3 KiB
Python
73 lines
2.3 KiB
Python
import os
|
|
from collections.abc import AsyncGenerator
|
|
from contextlib import asynccontextmanager
|
|
|
|
import torch
|
|
import uvicorn
|
|
from fastapi import FastAPI
|
|
from transformers import logging as transformer_logging # type:ignore
|
|
|
|
from danswer import __version__
|
|
from danswer.utils.logger import setup_logger
|
|
from model_server.custom_models import router as custom_models_router
|
|
from model_server.custom_models import warm_up_intent_model
|
|
from model_server.encoders import router as encoders_router
|
|
from model_server.encoders import warm_up_cross_encoders
|
|
from model_server.management_endpoints import router as management_router
|
|
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
|
|
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
|
|
from shared_configs.configs import INDEXING_ONLY
|
|
from shared_configs.configs import MIN_THREADS_ML_MODELS
|
|
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
|
|
from shared_configs.configs import MODEL_SERVER_PORT
|
|
|
|
# Process-wide environment flags consumed by the Hugging Face libraries:
# TOKENIZERS_PARALLELISM is set to "false" and HF_HUB_DISABLE_TELEMETRY to "1".
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "HF_HUB_DISABLE_TELEMETRY": "1",
    }
)
|
|
|
|
# Restrict the transformers library's own logging to error-level messages.
transformer_logging.set_verbosity_error()
|
|
|
|
# Module-level logger, built by the project's setup_logger() helper.
logger = setup_logger()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    """Run model-server startup work, then yield control to the application.

    Startup steps, in order:
      1. Log whether torch reports a CUDA device.
      2. Raise torch's thread count to at least MIN_THREADS_ML_MODELS and
         log the resulting value.
      3. Unless INDEXING_ONLY is set, call warm_up_intent_model() and, when
         either reranking flow flag is enabled, warm_up_cross_encoders().

    Nothing runs after the ``yield``; there is no shutdown work here.
    The ``app`` argument is accepted but not used.
    """
    gpu_message = (
        "GPU is available" if torch.cuda.is_available() else "GPU is not available"
    )
    logger.info(gpu_message)

    # Never lower the thread count torch already picked; only raise it to the
    # configured minimum.
    thread_count = max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
    torch.set_num_threads(thread_count)
    logger.info(f"Torch Threads: {torch.get_num_threads()}")

    if INDEXING_ONLY:
        logger.info("This model server should only run document indexing.")
    else:
        warm_up_intent_model()
        reranking_enabled = (
            ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW
        )
        if reranking_enabled:
            warm_up_cross_encoders()

    yield
|
|
|
|
|
|
def get_model_app() -> FastAPI:
    """Create the model server's FastAPI application.

    The app is titled "Danswer Model Server", carries the package
    ``__version__``, uses ``lifespan`` as its lifespan handler, and has the
    management, encoders and custom-models routers registered in that order.

    Returns:
        The configured FastAPI instance.
    """
    application = FastAPI(
        title="Danswer Model Server",
        version=__version__,
        lifespan=lifespan,
    )

    # Registration order matches the original: management, encoders, custom.
    for router in (management_router, encoders_router, custom_models_router):
        application.include_router(router)

    return application
|
|
|
|
|
|
# Module-level application instance, built at import time; it is the object
# passed to uvicorn.run() when this file is executed directly.
app = get_model_app()
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point: log the address the server is started on
    # and the running version, then hand the app to uvicorn.
    startup_message = (
        f"Starting Danswer Model Server on http://{MODEL_SERVER_ALLOWED_HOST}:{str(MODEL_SERVER_PORT)}/"
    )
    logger.info(startup_message)
    logger.info(f"Model Server Version: {__version__}")
    uvicorn.run(app, host=MODEL_SERVER_ALLOWED_HOST, port=MODEL_SERVER_PORT)
|