Always Use Model Server (#1306)

Yuhong Sun authored 2024-04-07 21:25:06 -07:00 · committed by GitHub
parent 795243283d · commit 2db906b7a2
35 changed files with 724 additions and 550 deletions

View File: model_server/constants.py

@@ -0,0 +1 @@
MODEL_WARM_UP_STRING = "hi " * 512
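The warm-up constant is a short token repeated 512 times, enough to fill a typical 512-token context window so that a single warm-up pass exercises full-length inference. A minimal sketch of checking that, assuming transformers is installed and using a purely illustrative checkpoint name:

from transformers import AutoTokenizer

from model_server.constants import MODEL_WARM_UP_STRING  # assumes the repo is on PYTHONPATH

# Illustrative checkpoint; the real models are configured elsewhere in the repo.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
encoded = tokenizer(MODEL_WARM_UP_STRING, truncation=True, max_length=512)
print(len(encoded["input_ids"]))  # 512: the warm-up string saturates the window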

View File: model_server/custom_models.py

@@ -1,19 +1,58 @@
import numpy as np
from fastapi import APIRouter
from typing import Optional
import numpy as np
import tensorflow as tf # type: ignore
from fastapi import APIRouter
from transformers import AutoTokenizer # type: ignore
from transformers import TFDistilBertForSequenceClassification
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.utils import simple_log_function_time
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.nlp_model_configs import INDEXING_ONLY
from shared_configs.nlp_model_configs import INTENT_MODEL_CONTEXT_SIZE
from shared_configs.nlp_model_configs import INTENT_MODEL_VERSION
from danswer.search.search_nlp_models import get_intent_model_tokenizer
from danswer.search.search_nlp_models import get_local_intent_model
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import IntentRequest
from shared_models.model_server_models import IntentResponse
router = APIRouter(prefix="/custom")
_INTENT_TOKENIZER: Optional[AutoTokenizer] = None
_INTENT_MODEL: Optional[TFDistilBertForSequenceClassification] = None
@log_function_time(print_only=True)
def get_intent_model_tokenizer(
model_name: str = INTENT_MODEL_VERSION,
) -> "AutoTokenizer":
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
_INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
return _INTENT_TOKENIZER
def get_local_intent_model(
model_name: str = INTENT_MODEL_VERSION,
max_context_length: int = INTENT_MODEL_CONTEXT_SIZE,
) -> TFDistilBertForSequenceClassification:
global _INTENT_MODEL
if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
_INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
model_name
)
_INTENT_MODEL.max_seq_length = max_context_length
return _INTENT_MODEL
def warm_up_intent_model() -> None:
intent_tokenizer = get_intent_model_tokenizer()
inputs = intent_tokenizer(
MODEL_WARM_UP_STRING, return_tensors="tf", truncation=True, padding=True
)
get_local_intent_model()(inputs)
@simple_log_function_time()
def classify_intent(query: str) -> list[float]:
import tensorflow as tf # type:ignore
tokenizer = get_intent_model_tokenizer()
intent_model = get_local_intent_model()
model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
@@ -26,16 +65,11 @@ def classify_intent(query: str) -> list[float]:
@router.post("/intent-model")
def process_intent_request(
async def process_intent_request(
intent_request: IntentRequest,
) -> IntentResponse:
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
class_percentages = classify_intent(intent_request.query)
return IntentResponse(class_probs=class_percentages)
def warm_up_intent_model() -> None:
intent_tokenizer = get_intent_model_tokenizer()
inputs = intent_tokenizer(
"danswer", return_tensors="tf", truncation=True, padding=True
)
get_local_intent_model()(inputs)
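With the intent model now living entirely in the model server, callers reach it over HTTP through the "/custom" router above. A minimal client sketch, assuming the server listens on localhost:9000 (the real port comes from MODEL_SERVER_PORT) and using only the IntentRequest/IntentResponse fields shown in this diff:

import requests

MODEL_SERVER_URL = "http://localhost:9000"  # assumed address and port


def classify_via_model_server(query: str) -> list[float]:
    # IntentRequest carries the raw query; the route is the "/custom" prefix
    # plus the "/intent-model" path registered above.
    response = requests.post(
        f"{MODEL_SERVER_URL}/custom/intent-model",
        json={"query": query},
    )
    response.raise_for_status()
    # IntentResponse returns one probability per intent class.
    return response.json()["class_probs"]


if __name__ == "__main__":
    print(classify_via_model_server("What is our refund policy?"))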

View File: model_server/encoders.py

@@ -1,34 +1,33 @@
from typing import TYPE_CHECKING
import gc
from typing import Optional
from fastapi import APIRouter
from fastapi import HTTPException
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse
from shared_models.model_server_models import RerankRequest
from shared_models.model_server_models import RerankResponse
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer # type: ignore
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.utils import simple_log_function_time
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
from shared_configs.nlp_model_configs import CROSS_EMBED_CONTEXT_SIZE
from shared_configs.nlp_model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from shared_configs.nlp_model_configs import INDEXING_ONLY
logger = setup_logger()
WARM_UP_STRING = "Danswer is amazing"
router = APIRouter(prefix="/encoder")
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
def get_embedding_model(
model_name: str,
max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
max_context_length: int,
) -> "SentenceTransformer":
from sentence_transformers import SentenceTransformer # type: ignore
@@ -48,11 +47,44 @@ def get_embedding_model(
return _GLOBAL_MODELS_DICT[model_name]
@log_function_time(print_only=True)
def get_local_reranking_model_ensemble(
model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
) -> list[CrossEncoder]:
global _RERANK_MODELS
if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
del _RERANK_MODELS
gc.collect()
_RERANK_MODELS = []
for model_name in model_names:
logger.info(f"Loading {model_name}")
model = CrossEncoder(model_name)
model.max_length = max_context_length
_RERANK_MODELS.append(model)
return _RERANK_MODELS
def warm_up_cross_encoders() -> None:
logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
cross_encoders = get_local_reranking_model_ensemble()
[
cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
for cross_encoder in cross_encoders
]
@simple_log_function_time()
def embed_text(
texts: list[str], model_name: str, normalize_embeddings: bool
texts: list[str],
model_name: str,
max_context_length: int,
normalize_embeddings: bool,
) -> list[list[float]]:
model = get_embedding_model(model_name=model_name)
model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
if not isinstance(embeddings, list):
@@ -61,7 +93,7 @@ def embed_text(
return embeddings
@log_function_time(print_only=True)
@simple_log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
cross_encoders = get_local_reranking_model_ensemble()
sim_scores = [
@@ -72,13 +104,14 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
@router.post("/bi-encoder-embed")
def process_embed_request(
async def process_embed_request(
embed_request: EmbedRequest,
) -> EmbedResponse:
try:
embeddings = embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
max_context_length=embed_request.max_context_length,
normalize_embeddings=embed_request.normalize_embeddings,
)
return EmbedResponse(embeddings=embeddings)
@@ -87,7 +120,11 @@ def process_embed_request(
@router.post("/cross-encoder-scores")
def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
"""Cross encoders can be purely black box from the app perspective"""
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call cross-encoder endpoint")
try:
sim_scores = calc_sim_scores(
query=embed_request.query, docs=embed_request.documents
@@ -95,13 +132,3 @@ def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
return RerankResponse(scores=sim_scores)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def warm_up_cross_encoders() -> None:
logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
cross_encoders = get_local_reranking_model_ensemble()
[
cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
for cross_encoder in cross_encoders
]
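The embedding endpoint now takes the maximum context length from the request rather than a server-side default. A hedged client sketch, with the address and model name as illustrative assumptions and the payload fields taken from the EmbedRequest model used above:

import requests

MODEL_SERVER_URL = "http://localhost:9000"  # assumed address and port


def embed_via_model_server(texts: list[str]) -> list[list[float]]:
    # max_context_length is now part of the request, so the caller controls
    # truncation instead of relying on DOC_EMBEDDING_CONTEXT_SIZE on the server.
    payload = {
        "texts": texts,
        "model_name": "intfloat/e5-base-v2",  # illustrative model name
        "max_context_length": 512,
        "normalize_embeddings": True,
    }
    response = requests.post(
        f"{MODEL_SERVER_URL}/encoder/bi-encoder-embed", json=payload
    )
    response.raise_for_status()
    return response.json()["embeddings"]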

View File: model_server/main.py

@@ -1,40 +1,61 @@
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import torch
import uvicorn
from fastapi import FastAPI
from transformers import logging as transformer_logging # type:ignore
from danswer import __version__
from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.utils.logger import setup_logger
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.encoders import warm_up_cross_encoders
from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW
from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from shared_configs.nlp_model_configs import INDEXING_ONLY
from shared_configs.nlp_model_configs import MIN_THREADS_ML_MODELS
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
transformer_logging.set_verbosity_error()
logger = setup_logger()
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
if torch.cuda.is_available():
logger.info("GPU is available")
else:
logger.info("GPU is not available")
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.info(f"Torch Threads: {torch.get_num_threads()}")
if not INDEXING_ONLY:
warm_up_intent_model()
if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
warm_up_cross_encoders()
else:
logger.info("This model server should only run document indexing.")
yield
def get_model_app() -> FastAPI:
application = FastAPI(title="Danswer Model Server", version=__version__)
application = FastAPI(
title="Danswer Model Server", version=__version__, lifespan=lifespan
)
application.include_router(encoders_router)
application.include_router(custom_models_router)
@application.on_event("startup")
def startup_event() -> None:
if torch.cuda.is_available():
logger.info("GPU is available")
else:
logger.info("GPU is not available")
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.info(f"Torch Threads: {torch.get_num_threads()}")
warm_up_cross_encoders()
warm_up_intent_model()
return application
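The warm-up logic moves from the deprecated @app.on_event("startup") hook to FastAPI's lifespan context manager. A stripped-down sketch of the same pattern, with placeholder bodies standing in for the warm-up calls:

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    # Everything before `yield` runs once at startup, replacing on_event("startup").
    print("warming up models...")  # stand-in for warm_up_intent_model() etc.
    yield
    # Anything after `yield` runs at shutdown, e.g. releasing model memory.
    print("shutting down model server")


app = FastAPI(title="Example Model Server", lifespan=lifespan)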

View File: model_server/utils.py

@@ -0,0 +1,41 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar
from danswer.utils.logger import setup_logger
logger = setup_logger()
F = TypeVar("F", bound=Callable)
FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])
def simple_log_function_time(
func_name: str | None = None,
debug_only: bool = False,
include_args: bool = False,
) -> Callable[[F], F]:
def decorator(func: F) -> F:
@wraps(func)
def wrapped_func(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
result = func(*args, **kwargs)
elapsed_time_str = str(time.time() - start_time)
log_name = func_name or func.__name__
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
if debug_only:
logger.debug(final_log)
else:
logger.info(final_log)
return result
return cast(F, wrapped_func)
return decorator
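A short usage sketch of the decorator above; the decorated function is illustrative and the import assumes the repo's model_server package is importable:

from model_server.utils import simple_log_function_time


@simple_log_function_time(include_args=True)
def add(a: int, b: int) -> int:
    return a + b


# Logs roughly: "add args=(1, 2) kwargs={} took 0.000012 seconds"
add(1, 2)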