"""Model server endpoints for the NLP models.

Pulls the NLP models out into a separate model server, which can then be
hosted on a GPU instance if desired.
"""

from fastapi import APIRouter
from fastapi import HTTPException
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.search.search_nlp_models import get_local_embedding_model
from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse
from shared_models.model_server_models import RerankRequest
from shared_models.model_server_models import RerankResponse
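
# A minimal sketch of how a client might call these endpoints once the model
# server is running (the host and port below are assumptions, not defined in
# this module):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:9000/encoder/bi-encoder-embed",
#       json={"texts": ["some passage to embed"]},
#   )
#   embeddings = resp.json()["embeddings"]
#
#   resp = requests.post(
#       "http://localhost:9000/encoder/cross-encoder-scores",
#       json={"query": "some query", "documents": ["candidate doc 1", "candidate doc 2"]},
#   )
#   scores_per_encoder = resp.json()["scores"]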

logger = setup_logger()

WARM_UP_STRING = "Danswer is amazing"

router = APIRouter(prefix="/encoder")


@log_function_time()
def embed_text(
    texts: list[str],
    normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
) -> list[list[float]]:
    # Embed the batch of texts with the locally loaded bi-encoder model
    model = get_local_embedding_model()
    embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)

    # The model may return a numpy array; convert it so it can be serialized
    if not isinstance(embeddings, list):
        embeddings = embeddings.tolist()

    return embeddings


@log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
    # Score every (query, doc) pair with each cross-encoder in the ensemble,
    # yielding one list of similarity scores per encoder
    cross_encoders = get_local_reranking_model_ensemble()
    sim_scores = [
        encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore
        for encoder in cross_encoders
    ]
    return sim_scores
@router.post("/bi-encoder-embed")
def process_embed_request(
embed_request: EmbedRequest,
) -> EmbedResponse:
try:
embeddings = embed_text(texts=embed_request.texts)
return EmbedResponse(embeddings=embeddings)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/cross-encoder-scores")
def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
try:
sim_scores = calc_sim_scores(
query=embed_request.query, docs=embed_request.documents
)
return RerankResponse(scores=sim_scores)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


def warm_up_bi_encoder() -> None:
    # Run a single dummy encode so the embedding model is loaded before serving requests
    logger.info(f"Warming up Bi-Encoders: {DOCUMENT_ENCODER_MODEL}")
    get_local_embedding_model().encode(WARM_UP_STRING)


def warm_up_cross_encoders() -> None:
    # Run a single dummy prediction per cross-encoder so the models are loaded up front
    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
    cross_encoders = get_local_reranking_model_ensemble()
    for cross_encoder in cross_encoders:
        cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
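

# A minimal sketch of how this router might be mounted in a FastAPI app (the
# wiring below is an assumption; the actual model server sets this up elsewhere):
#
#   from fastapi import FastAPI
#
#   app = FastAPI()
#   app.include_router(router)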