mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-07 05:30:40 +02:00
Provides the ability to pull out the NLP models into a separate model server which can then be hosted on a GPU instance if desired.
82 lines
2.6 KiB
Python
82 lines
2.6 KiB
Python
from fastapi import APIRouter
|
|
from fastapi import HTTPException
|
|
|
|
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
|
|
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
|
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
|
from danswer.search.search_nlp_models import get_local_embedding_model
|
|
from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
|
|
from danswer.utils.logger import setup_logger
|
|
from danswer.utils.timing import log_function_time
|
|
from shared_models.model_server_models import EmbedRequest
|
|
from shared_models.model_server_models import EmbedResponse
|
|
from shared_models.model_server_models import RerankRequest
|
|
from shared_models.model_server_models import RerankResponse
|
|
|
|
# Throwaway input fed to the models by the warm-up helpers in this module.
WARM_UP_STRING = "Danswer is amazing"

# Module-level logger shared by the functions in this module.
logger = setup_logger()

# Every route registered on this router is served under the "/encoder" prefix.
router = APIRouter(prefix="/encoder")
|
|
|
|
|
|
@log_function_time()
def embed_text(
    texts: list[str],
    normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
) -> list[list[float]]:
    """Embed ``texts`` with the local embedding model.

    Args:
        texts: The strings to encode.
        normalize_embeddings: Forwarded unchanged to the model's ``encode``
            call; defaults to the ``NORMALIZE_EMBEDDINGS`` config value.

    Returns:
        The encoded vectors. If ``encode`` returns something other than a
        ``list``, the result of its ``tolist()`` method is returned instead.
    """
    encoder = get_local_embedding_model()
    vectors = encoder.encode(texts, normalize_embeddings=normalize_embeddings)

    # A list result is passed through as-is; anything else is converted
    # through its own tolist() method.
    return vectors if isinstance(vectors, list) else vectors.tolist()
|
|
|
|
|
|
@log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
    """Score each document against ``query`` with every reranking model.

    Args:
        query: The query string paired with each document.
        docs: The documents to score.

    Returns:
        One list of scores per model in the local reranking ensemble, in the
        order the ensemble yields its models. Each inner list is the
        ``tolist()`` of that model's ``predict`` output for the
        ``(query, doc)`` pairs.
    """
    ensemble_scores = []
    for encoder in get_local_reranking_model_ensemble():
        # A fresh list of (query, doc) pairs is built for every model.
        query_doc_pairs = [(query, doc) for doc in docs]
        ensemble_scores.append(encoder.predict(query_doc_pairs).tolist())  # type: ignore
    return ensemble_scores
|
|
|
|
|
|
@router.post("/bi-encoder-embed")
def process_embed_request(
    embed_request: EmbedRequest,
) -> EmbedResponse:
    """Handle a POST to ``/bi-encoder-embed`` by embedding the request texts.

    Args:
        embed_request: Request body; its ``texts`` field is passed to
            ``embed_text``.

    Returns:
        An ``EmbedResponse`` wrapping the computed embeddings.

    Raises:
        HTTPException: With status 500 and the stringified error as detail
            when embedding or response construction fails.
    """
    try:
        embeddings = embed_text(texts=embed_request.texts)
        return EmbedResponse(embeddings=embeddings)
    except Exception as e:
        # Record the full traceback server-side; the client only receives the
        # stringified message in the 500 response.
        logger.exception("Failed to process bi-encoder embed request")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.post("/cross-encoder-scores")
def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
    """Handle a POST to ``/cross-encoder-scores`` by scoring the documents.

    Args:
        embed_request: Request body; its ``query`` and ``documents`` fields
            are passed to ``calc_sim_scores``.

    Returns:
        A ``RerankResponse`` wrapping the similarity scores.

    Raises:
        HTTPException: With status 500 and the stringified error as detail
            when scoring or response construction fails.
    """
    try:
        sim_scores = calc_sim_scores(
            query=embed_request.query, docs=embed_request.documents
        )
        return RerankResponse(scores=sim_scores)
    except Exception as e:
        # Record the full traceback server-side; the client only receives the
        # stringified message in the 500 response.
        logger.exception("Failed to process cross-encoder rerank request")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
def warm_up_bi_encoder() -> None:
    """Run one throwaway encode of ``WARM_UP_STRING`` on the embedding model.

    The result is discarded; the call is made only for its warm-up effect.
    """
    logger.info(f"Warming up Bi-Encoders: {DOCUMENT_ENCODER_MODEL}")
    embedding_model = get_local_embedding_model()
    embedding_model.encode(WARM_UP_STRING)
|
|
|
|
|
|
def warm_up_cross_encoders() -> None:
    """Run one throwaway prediction on every model in the reranking ensemble.

    Each model scores the pair ``(WARM_UP_STRING, WARM_UP_STRING)``; the
    results are discarded.
    """
    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")

    # Plain loop: the calls are made only for their side effect, so no list of
    # results is built.
    for cross_encoder in get_local_reranking_model_ensemble():
        cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
|