Mirror of https://github.com/danswer-ai/danswer.git
Model Server (#695)
Pulls the NLP models out into a separate model server, which can then be hosted on a GPU instance if desired.
0  backend/model_server/__init__.py  Normal file
40  backend/model_server/custom_models.py  Normal file
@@ -0,0 +1,40 @@
import numpy as np
import tensorflow as tf  # type:ignore
from fastapi import APIRouter

from danswer.search.search_nlp_models import get_intent_model_tokenizer
from danswer.search.search_nlp_models import get_local_intent_model
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import IntentRequest
from shared_models.model_server_models import IntentResponse

router = APIRouter(prefix="/custom")


@log_function_time()
def classify_intent(query: str) -> list[float]:
    tokenizer = get_intent_model_tokenizer()
    intent_model = get_local_intent_model()
    model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)

    predictions = intent_model(model_input)[0]
    probabilities = tf.nn.softmax(predictions, axis=-1)

    class_percentages = np.round(probabilities.numpy() * 100, 2)
    return list(class_percentages.tolist()[0])


@router.post("/intent-model")
def process_intent_request(
    intent_request: IntentRequest,
) -> IntentResponse:
    class_percentages = classify_intent(intent_request.query)
    return IntentResponse(class_probs=class_percentages)


def warm_up_intent_model() -> None:
    intent_tokenizer = get_intent_model_tokenizer()
    inputs = intent_tokenizer(
        "danswer", return_tensors="tf", truncation=True, padding=True
    )
    get_local_intent_model()(inputs)
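For a quick sanity check, the new intent endpoint can be exercised with a small client along these lines. This sketch is not part of the commit: the host and port (localhost:9000) and the example query are assumptions, while the /custom/intent-model path and the query / class_probs fields come straight from the handler above.

# Client sketch (not in this commit). Assumes the model server is reachable
# at localhost:9000; adjust the address to your deployment.
import requests

resp = requests.post(
    "http://localhost:9000/custom/intent-model",
    json={"query": "How do I reset my password?"},  # IntentRequest payload
)
resp.raise_for_status()
# IntentResponse: one rounded percentage per intent class,
# e.g. {"class_probs": [1.2, 97.5, 1.3]}
print(resp.json())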
81  backend/model_server/encoders.py  Normal file
@@ -0,0 +1,81 @@
from fastapi import APIRouter
from fastapi import HTTPException

from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.search.search_nlp_models import get_local_embedding_model
from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse
from shared_models.model_server_models import RerankRequest
from shared_models.model_server_models import RerankResponse

logger = setup_logger()

WARM_UP_STRING = "Danswer is amazing"

router = APIRouter(prefix="/encoder")


@log_function_time()
def embed_text(
    texts: list[str],
    normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
) -> list[list[float]]:
    model = get_local_embedding_model()
    embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)

    if not isinstance(embeddings, list):
        embeddings = embeddings.tolist()

    return embeddings


@log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
    cross_encoders = get_local_reranking_model_ensemble()
    sim_scores = [
        encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore
        for encoder in cross_encoders
    ]
    return sim_scores


@router.post("/bi-encoder-embed")
def process_embed_request(
    embed_request: EmbedRequest,
) -> EmbedResponse:
    try:
        embeddings = embed_text(texts=embed_request.texts)
        return EmbedResponse(embeddings=embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/cross-encoder-scores")
def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
    try:
        sim_scores = calc_sim_scores(
            query=embed_request.query, docs=embed_request.documents
        )
        return RerankResponse(scores=sim_scores)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def warm_up_bi_encoder() -> None:
    logger.info(f"Warming up Bi-Encoders: {DOCUMENT_ENCODER_MODEL}")
    get_local_embedding_model().encode(WARM_UP_STRING)


def warm_up_cross_encoders() -> None:
    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")

    cross_encoders = get_local_reranking_model_ensemble()
    [
        cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
        for cross_encoder in cross_encoders
    ]
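The request/response types imported from shared_models.model_server_models are not part of this diff. A minimal sketch of what they plausibly look like, with field names and types inferred from the handlers above and pydantic assumed as the schema library:

# Inferred sketch of shared_models/model_server_models.py (not shown in this diff).
from pydantic import BaseModel


class IntentRequest(BaseModel):
    query: str


class IntentResponse(BaseModel):
    # rounded percentages, one per intent class
    class_probs: list[float]


class EmbedRequest(BaseModel):
    texts: list[str]


class EmbedResponse(BaseModel):
    # one embedding vector per input text
    embeddings: list[list[float]]


class RerankRequest(BaseModel):
    query: str
    documents: list[str]


class RerankResponse(BaseModel):
    # one score list per cross-encoder in the ensemble
    scores: list[list[float]]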
51  backend/model_server/main.py  Normal file
@@ -0,0 +1,51 @@
import torch
import uvicorn
from fastapi import FastAPI

from danswer import __version__
from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.utils.logger import setup_logger
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.encoders import warm_up_bi_encoder
from model_server.encoders import warm_up_cross_encoders


logger = setup_logger()


def get_model_app() -> FastAPI:
    application = FastAPI(title="Danswer Model Server", version=__version__)

    application.include_router(encoders_router)
    application.include_router(custom_models_router)

    @application.on_event("startup")
    def startup_event() -> None:
        if torch.cuda.is_available():
            logger.info("GPU is available")
        else:
            logger.info("GPU is not available")

        torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
        logger.info(f"Torch Threads: {torch.get_num_threads()}")

        warm_up_bi_encoder()
        warm_up_cross_encoders()
        warm_up_intent_model()

    return application


app = get_model_app()


if __name__ == "__main__":
    logger.info(
        f"Starting Danswer Model Server on http://{MODEL_SERVER_ALLOWED_HOST}:{str(MODEL_SERVER_PORT)}/"
    )
    logger.info(f"Model Server Version: {__version__}")
    uvicorn.run(app, host=MODEL_SERVER_ALLOWED_HOST, port=MODEL_SERVER_PORT)
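With the server started via the __main__ block above, both encoder endpoints can be smoke-tested together. As before, this is only a sketch: the base URL and sample texts are assumptions; the paths and payload fields come from encoders.py.

# Smoke-test sketch (not in this commit). Assumes the server listens on localhost:9000.
import requests

BASE_URL = "http://localhost:9000"

embed = requests.post(
    f"{BASE_URL}/encoder/bi-encoder-embed",
    json={"texts": ["What is Danswer?", "Danswer is an open source workplace QA tool."]},
)
embed.raise_for_status()
print(len(embed.json()["embeddings"]), "embeddings returned")

rerank = requests.post(
    f"{BASE_URL}/encoder/cross-encoder-scores",
    json={
        "query": "What is Danswer?",
        "documents": ["Danswer is an open source workplace QA tool."],
    },
)
rerank.raise_for_status()
print(rerank.json()["scores"])  # one score list per cross-encoder in the ensemble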