Background Index Attempt Creation (#1010)

This commit is contained in:
Yuhong Sun
2024-01-28 23:14:20 -08:00
committed by GitHub
parent c0c9c67534
commit 4b45164496
35 changed files with 1022 additions and 370 deletions

View File

@@ -1,10 +1,10 @@
from typing import TYPE_CHECKING
from fastapi import APIRouter
from fastapi import HTTPException
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.search.search_nlp_models import get_local_embedding_model
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
@@ -13,19 +13,46 @@ from shared_models.model_server_models import EmbedResponse
from shared_models.model_server_models import RerankRequest
from shared_models.model_server_models import RerankResponse
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer # type: ignore
logger = setup_logger()
WARM_UP_STRING = "Danswer is amazing"
router = APIRouter(prefix="/encoder")
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
def get_embedding_model(
    model_name: str,
    max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> "SentenceTransformer":
    """Return a cached SentenceTransformer for `model_name`, loading it on first use.

    Loaded models are memoized in the module-level `_GLOBAL_MODELS_DICT` so
    repeated requests for the same model name reuse the in-memory instance.
    If a caller asks for a different `max_context_length` than the cached
    instance currently uses, the cached model's `max_seq_length` is updated
    in place (this affects all users of the shared instance).

    Args:
        model_name: Hugging Face / sentence-transformers model identifier.
        max_context_length: Maximum sequence length to configure on the model.

    Returns:
        The (possibly freshly loaded) SentenceTransformer instance.
    """
    # Imported lazily so the heavy sentence_transformers dependency is only
    # paid for when an embedding model is actually requested.
    from sentence_transformers import SentenceTransformer  # type: ignore

    # NOTE: the original `if _GLOBAL_MODELS_DICT is None: _GLOBAL_MODELS_DICT = {}`
    # guard was dead code — the dict is initialized to {} at module level and is
    # never rebound to None. With the rebinding gone, no `global` statement is
    # needed: the dict is only mutated, never reassigned.
    if model_name not in _GLOBAL_MODELS_DICT:
        logger.info(f"Loading {model_name}")
        model = SentenceTransformer(model_name)
        model.max_seq_length = max_context_length
        _GLOBAL_MODELS_DICT[model_name] = model
    elif max_context_length != _GLOBAL_MODELS_DICT[model_name].max_seq_length:
        _GLOBAL_MODELS_DICT[model_name].max_seq_length = max_context_length

    return _GLOBAL_MODELS_DICT[model_name]
@log_function_time(print_only=True)
def embed_text(
texts: list[str],
normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
texts: list[str], model_name: str, normalize_embeddings: bool
) -> list[list[float]]:
model = get_local_embedding_model()
model = get_embedding_model(model_name=model_name)
embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
if not isinstance(embeddings, list):
@@ -49,7 +76,11 @@ def process_embed_request(
embed_request: EmbedRequest,
) -> EmbedResponse:
try:
embeddings = embed_text(texts=embed_request.texts)
embeddings = embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
normalize_embeddings=embed_request.normalize_embeddings,
)
return EmbedResponse(embeddings=embeddings)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -66,11 +97,6 @@ def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
raise HTTPException(status_code=500, detail=str(e))
def warm_up_bi_encoder() -> None:
    # Eagerly load the default bi-encoder by encoding one throwaway string,
    # so model load/initialization cost is paid at startup rather than on
    # the first real embedding request.
    logger.info(f"Warming up Bi-Encoders: {DOCUMENT_ENCODER_MODEL}")
    get_local_embedding_model().encode(WARM_UP_STRING)
def warm_up_cross_encoders() -> None:
logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")

View File

@@ -10,7 +10,6 @@ from danswer.utils.logger import setup_logger
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.encoders import warm_up_bi_encoder
from model_server.encoders import warm_up_cross_encoders
@@ -33,7 +32,6 @@ def get_model_app() -> FastAPI:
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.info(f"Torch Threads: {torch.get_num_threads()}")
warm_up_bi_encoder()
warm_up_cross_encoders()
warm_up_intent_model()