Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-07-12 06:05:43 +02:00
Always Use Model Server (#1306)
1  backend/model_server/constants.py  Normal file
@@ -0,0 +1 @@
+MODEL_WARM_UP_STRING = "hi " * 512
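Presumably the warm-up string is sized so that, once tokenized, it fills most of the intent model's context window. A rough way to check that assumption locally, using a generic DistilBERT tokenizer as a stand-in for the actual INTENT_MODEL_VERSION checkpoint:

from transformers import AutoTokenizer  # type: ignore

MODEL_WARM_UP_STRING = "hi " * 512

# Stand-in tokenizer; the real one is loaded via get_intent_model_tokenizer().
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
token_count = len(tokenizer(MODEL_WARM_UP_STRING)["input_ids"])
print(f"warm-up string tokenizes to {token_count} tokens")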
backend/model_server/custom_models.py
@@ -1,19 +1,58 @@
-import numpy as np
-from fastapi import APIRouter
+from typing import Optional
+
+import numpy as np
+import tensorflow as tf  # type: ignore
+from fastapi import APIRouter
+from transformers import AutoTokenizer  # type: ignore
+from transformers import TFDistilBertForSequenceClassification
+
+from model_server.constants import MODEL_WARM_UP_STRING
+from model_server.utils import simple_log_function_time
+from shared_configs.model_server_models import IntentRequest
+from shared_configs.model_server_models import IntentResponse
+from shared_configs.nlp_model_configs import INDEXING_ONLY
+from shared_configs.nlp_model_configs import INTENT_MODEL_CONTEXT_SIZE
+from shared_configs.nlp_model_configs import INTENT_MODEL_VERSION
 
-from danswer.search.search_nlp_models import get_intent_model_tokenizer
-from danswer.search.search_nlp_models import get_local_intent_model
-from danswer.utils.timing import log_function_time
-from shared_models.model_server_models import IntentRequest
-from shared_models.model_server_models import IntentResponse
 
 router = APIRouter(prefix="/custom")
 
+_INTENT_TOKENIZER: Optional[AutoTokenizer] = None
+_INTENT_MODEL: Optional[TFDistilBertForSequenceClassification] = None
+
 
-@log_function_time(print_only=True)
+def get_intent_model_tokenizer(
+    model_name: str = INTENT_MODEL_VERSION,
+) -> "AutoTokenizer":
+    global _INTENT_TOKENIZER
+    if _INTENT_TOKENIZER is None:
+        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
+    return _INTENT_TOKENIZER
+
+
+def get_local_intent_model(
+    model_name: str = INTENT_MODEL_VERSION,
+    max_context_length: int = INTENT_MODEL_CONTEXT_SIZE,
+) -> TFDistilBertForSequenceClassification:
+    global _INTENT_MODEL
+    if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
+        _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
+            model_name
+        )
+        _INTENT_MODEL.max_seq_length = max_context_length
+    return _INTENT_MODEL
+
+
+def warm_up_intent_model() -> None:
+    intent_tokenizer = get_intent_model_tokenizer()
+    inputs = intent_tokenizer(
+        MODEL_WARM_UP_STRING, return_tensors="tf", truncation=True, padding=True
+    )
+    get_local_intent_model()(inputs)
+
+
+@simple_log_function_time()
 def classify_intent(query: str) -> list[float]:
-    import tensorflow as tf  # type:ignore
-
     tokenizer = get_intent_model_tokenizer()
     intent_model = get_local_intent_model()
     model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
@@ -26,16 +65,11 @@ def classify_intent(query: str) -> list[float]:
 
 
 @router.post("/intent-model")
-def process_intent_request(
+async def process_intent_request(
     intent_request: IntentRequest,
 ) -> IntentResponse:
+    if INDEXING_ONLY:
+        raise RuntimeError("Indexing model server should not call intent endpoint")
+
     class_percentages = classify_intent(intent_request.query)
     return IntentResponse(class_probs=class_percentages)
-
-
-def warm_up_intent_model() -> None:
-    intent_tokenizer = get_intent_model_tokenizer()
-    inputs = intent_tokenizer(
-        "danswer", return_tensors="tf", truncation=True, padding=True
-    )
-    get_local_intent_model()(inputs)
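With the intent endpoint living behind the model server, the API server calls it over HTTP instead of loading the model in-process. A minimal client sketch, assuming a model server listening on http://localhost:9000 (the address is illustrative; the payload fields mirror IntentRequest/IntentResponse above):

import requests

MODEL_SERVER_URL = "http://localhost:9000"  # hypothetical address

response = requests.post(
    f"{MODEL_SERVER_URL}/custom/intent-model",
    json={"query": "How do I reset my password?"},  # IntentRequest.query
)
response.raise_for_status()
class_probs = response.json()["class_probs"]  # IntentResponse.class_probs
print(class_probs)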
backend/model_server/encoders.py
@@ -1,34 +1,33 @@
-from typing import TYPE_CHECKING
+import gc
+from typing import Optional
 
 from fastapi import APIRouter
 from fastapi import HTTPException
+from sentence_transformers import CrossEncoder  # type: ignore
+from sentence_transformers import SentenceTransformer  # type: ignore
 
-from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
-from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
-from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
 from danswer.utils.logger import setup_logger
-from danswer.utils.timing import log_function_time
-from shared_models.model_server_models import EmbedRequest
-from shared_models.model_server_models import EmbedResponse
-from shared_models.model_server_models import RerankRequest
-from shared_models.model_server_models import RerankResponse
-
-if TYPE_CHECKING:
-    from sentence_transformers import SentenceTransformer  # type: ignore
-
+from model_server.constants import MODEL_WARM_UP_STRING
+from model_server.utils import simple_log_function_time
+from shared_configs.model_server_models import EmbedRequest
+from shared_configs.model_server_models import EmbedResponse
+from shared_configs.model_server_models import RerankRequest
+from shared_configs.model_server_models import RerankResponse
+from shared_configs.nlp_model_configs import CROSS_EMBED_CONTEXT_SIZE
+from shared_configs.nlp_model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
+from shared_configs.nlp_model_configs import INDEXING_ONLY
 
 logger = setup_logger()
 
-WARM_UP_STRING = "Danswer is amazing"
-
 router = APIRouter(prefix="/encoder")
 
 _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
+_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
 
 
 def get_embedding_model(
     model_name: str,
-    max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+    max_context_length: int,
 ) -> "SentenceTransformer":
     from sentence_transformers import SentenceTransformer  # type: ignore
@@ -48,11 +47,44 @@ def get_embedding_model(
     return _GLOBAL_MODELS_DICT[model_name]
 
 
-@log_function_time(print_only=True)
+def get_local_reranking_model_ensemble(
+    model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
+    max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
+) -> list[CrossEncoder]:
+    global _RERANK_MODELS
+    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
+        del _RERANK_MODELS
+        gc.collect()
+
+        _RERANK_MODELS = []
+        for model_name in model_names:
+            logger.info(f"Loading {model_name}")
+            model = CrossEncoder(model_name)
+            model.max_length = max_context_length
+            _RERANK_MODELS.append(model)
+    return _RERANK_MODELS
+
+
+def warm_up_cross_encoders() -> None:
+    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
+
+    cross_encoders = get_local_reranking_model_ensemble()
+    [
+        cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
+        for cross_encoder in cross_encoders
+    ]
+
+
+@simple_log_function_time()
 def embed_text(
-    texts: list[str], model_name: str, normalize_embeddings: bool
+    texts: list[str],
+    model_name: str,
+    max_context_length: int,
+    normalize_embeddings: bool,
 ) -> list[list[float]]:
-    model = get_embedding_model(model_name=model_name)
+    model = get_embedding_model(
+        model_name=model_name, max_context_length=max_context_length
+    )
     embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
 
     if not isinstance(embeddings, list):
@@ -61,7 +93,7 @@ def embed_text(
     return embeddings
 
 
-@log_function_time(print_only=True)
+@simple_log_function_time()
 def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
     cross_encoders = get_local_reranking_model_ensemble()
     sim_scores = [
@@ -72,13 +104,14 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
 
 
 @router.post("/bi-encoder-embed")
-def process_embed_request(
+async def process_embed_request(
     embed_request: EmbedRequest,
 ) -> EmbedResponse:
     try:
         embeddings = embed_text(
             texts=embed_request.texts,
             model_name=embed_request.model_name,
+            max_context_length=embed_request.max_context_length,
             normalize_embeddings=embed_request.normalize_embeddings,
         )
         return EmbedResponse(embeddings=embeddings)
@@ -87,7 +120,11 @@ def process_embed_request(
 
 
 @router.post("/cross-encoder-scores")
-def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
+async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
     """Cross encoders can be purely black box from the app perspective"""
+    if INDEXING_ONLY:
+        raise RuntimeError("Indexing model server should not call intent endpoint")
+
     try:
         sim_scores = calc_sim_scores(
             query=embed_request.query, docs=embed_request.documents
@@ -95,13 +132,3 @@ def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
         return RerankResponse(scores=sim_scores)
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
-
-
-def warm_up_cross_encoders() -> None:
-    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
-
-    cross_encoders = get_local_reranking_model_ensemble()
-    [
-        cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
-        for cross_encoder in cross_encoders
-    ]
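The encoder endpoints are consumed the same way. A rough sketch of an embedding call against a locally running model server; the URL and model name are placeholders, while the payload fields mirror the EmbedRequest model used by embed_text above:

import requests

MODEL_SERVER_URL = "http://localhost:9000"  # hypothetical address

payload = {
    "texts": ["Danswer routes all model inference through the model server."],
    "model_name": "intfloat/e5-base-v2",  # placeholder embedding model
    "max_context_length": 512,
    "normalize_embeddings": True,
}
resp = requests.post(f"{MODEL_SERVER_URL}/encoder/bi-encoder-embed", json=payload)
resp.raise_for_status()
embeddings = resp.json()["embeddings"]  # list[list[float]], per EmbedResponse
print(len(embeddings), len(embeddings[0]))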
backend/model_server/main.py
@@ -1,40 +1,61 @@
+import os
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
+
 import torch
 import uvicorn
 from fastapi import FastAPI
+from transformers import logging as transformer_logging  # type:ignore
 
 from danswer import __version__
 from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
 from danswer.configs.app_configs import MODEL_SERVER_PORT
-from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
 from danswer.utils.logger import setup_logger
 from model_server.custom_models import router as custom_models_router
 from model_server.custom_models import warm_up_intent_model
 from model_server.encoders import router as encoders_router
 from model_server.encoders import warm_up_cross_encoders
+from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW
+from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
+from shared_configs.nlp_model_configs import INDEXING_ONLY
+from shared_configs.nlp_model_configs import MIN_THREADS_ML_MODELS
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
+
+transformer_logging.set_verbosity_error()
 
 logger = setup_logger()
 
 
+@asynccontextmanager
+async def lifespan(app: FastAPI) -> AsyncGenerator:
+    if torch.cuda.is_available():
+        logger.info("GPU is available")
+    else:
+        logger.info("GPU is not available")
+
+    torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
+    logger.info(f"Torch Threads: {torch.get_num_threads()}")
+
+    if not INDEXING_ONLY:
+        warm_up_intent_model()
+        if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
+            warm_up_cross_encoders()
+    else:
+        logger.info("This model server should only run document indexing.")
+
+    yield
+
+
 def get_model_app() -> FastAPI:
-    application = FastAPI(title="Danswer Model Server", version=__version__)
+    application = FastAPI(
+        title="Danswer Model Server", version=__version__, lifespan=lifespan
+    )
 
     application.include_router(encoders_router)
     application.include_router(custom_models_router)
 
-    @application.on_event("startup")
-    def startup_event() -> None:
-        if torch.cuda.is_available():
-            logger.info("GPU is available")
-        else:
-            logger.info("GPU is not available")
-
-        torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
-        logger.info(f"Torch Threads: {torch.get_num_threads()}")
-
-        warm_up_cross_encoders()
-        warm_up_intent_model()
-
     return application
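The lifespan context manager replaces the @application.on_event("startup") hook, so warm-up now runs once when uvicorn brings the app up, and only the flows enabled for that server instance are warmed. A minimal launch sketch, assuming get_model_app() is importable from model_server.main; the host and port are illustrative stand-ins for MODEL_SERVER_ALLOWED_HOST and MODEL_SERVER_PORT:

import uvicorn

from model_server.main import get_model_app

if __name__ == "__main__":
    app = get_model_app()
    # FastAPI runs the lifespan context manager on startup, which triggers
    # warm_up_intent_model() / warm_up_cross_encoders() before serving requests.
    uvicorn.run(app, host="0.0.0.0", port=9000)  # illustrative host/port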
41  backend/model_server/utils.py  Normal file
@@ -0,0 +1,41 @@
+import time
+from collections.abc import Callable
+from collections.abc import Generator
+from collections.abc import Iterator
+from functools import wraps
+from typing import Any
+from typing import cast
+from typing import TypeVar
+
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+F = TypeVar("F", bound=Callable)
+FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])
+
+
+def simple_log_function_time(
+    func_name: str | None = None,
+    debug_only: bool = False,
+    include_args: bool = False,
+) -> Callable[[F], F]:
+    def decorator(func: F) -> F:
+        @wraps(func)
+        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
+            start_time = time.time()
+            result = func(*args, **kwargs)
+            elapsed_time_str = str(time.time() - start_time)
+            log_name = func_name or func.__name__
+            args_str = f" args={args} kwargs={kwargs}" if include_args else ""
+            final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
+            if debug_only:
+                logger.debug(final_log)
+            else:
+                logger.info(final_log)
+
+            return result
+
+        return cast(F, wrapped_func)
+
+    return decorator
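For illustration, a hedged sketch of how the new decorator behaves; the decorated function and the logged timing below are made up, only simple_log_function_time itself comes from this commit:

from model_server.utils import simple_log_function_time


@simple_log_function_time(include_args=True)
def add(a: int, b: int) -> int:
    return a + b


# Calling add(2, 3) logs roughly:
#   add args=(2, 3) kwargs={} took 1.9073486328125e-06 seconds
print(add(2, 3))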