mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-28 13:53:28 +02:00
Background Index Attempt Creation (#1010)
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
|
||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.search.search_nlp_models import get_local_embedding_model
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_function_time
|
||||
@@ -13,19 +13,46 @@ from shared_models.model_server_models import EmbedResponse
|
||||
from shared_models.model_server_models import RerankRequest
|
||||
from shared_models.model_server_models import RerankResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
WARM_UP_STRING = "Danswer is amazing"
|
||||
|
||||
router = APIRouter(prefix="/encoder")
|
||||
|
||||
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
|
||||
|
||||
|
||||
def get_embedding_model(
    model_name: str,
    max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> "SentenceTransformer":
    """Return a cached SentenceTransformer for ``model_name``, loading it on first use.

    Loaded models are kept in the module-level ``_GLOBAL_MODELS_DICT`` so repeated
    requests for the same model reuse a single instance. If a cached model's
    ``max_seq_length`` differs from ``max_context_length``, it is updated in place.

    Args:
        model_name: Hugging Face / sentence-transformers model identifier.
        max_context_length: Maximum sequence length to configure on the model.

    Returns:
        The (possibly cached) SentenceTransformer instance.
    """
    # Deferred import: sentence_transformers is heavy, and only this function needs it.
    from sentence_transformers import SentenceTransformer  # type: ignore

    # NOTE(review): _GLOBAL_MODELS_DICT is initialized to {} at module scope, so the
    # previous `if _GLOBAL_MODELS_DICT is None` guard was unreachable and has been
    # removed. The dict is mutated (not rebound), so no `global` statement is needed.
    if model_name not in _GLOBAL_MODELS_DICT:
        logger.info(f"Loading {model_name}")
        model = SentenceTransformer(model_name)
        model.max_seq_length = max_context_length
        _GLOBAL_MODELS_DICT[model_name] = model
    elif max_context_length != _GLOBAL_MODELS_DICT[model_name].max_seq_length:
        # Reuse the cached model but honor the newly requested context length.
        _GLOBAL_MODELS_DICT[model_name].max_seq_length = max_context_length

    return _GLOBAL_MODELS_DICT[model_name]
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def embed_text(
|
||||
texts: list[str],
|
||||
normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
|
||||
texts: list[str], model_name: str, normalize_embeddings: bool
|
||||
) -> list[list[float]]:
|
||||
model = get_local_embedding_model()
|
||||
model = get_embedding_model(model_name=model_name)
|
||||
embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
|
||||
|
||||
if not isinstance(embeddings, list):
|
||||
@@ -49,7 +76,11 @@ def process_embed_request(
|
||||
embed_request: EmbedRequest,
|
||||
) -> EmbedResponse:
|
||||
try:
|
||||
embeddings = embed_text(texts=embed_request.texts)
|
||||
embeddings = embed_text(
|
||||
texts=embed_request.texts,
|
||||
model_name=embed_request.model_name,
|
||||
normalize_embeddings=embed_request.normalize_embeddings,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -66,11 +97,6 @@ def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def warm_up_bi_encoder() -> None:
    """Load the bi-encoder and run one inference pass so the first real request is fast."""
    logger.info(f"Warming up Bi-Encoders: {DOCUMENT_ENCODER_MODEL}")
    model = get_local_embedding_model()
    model.encode(WARM_UP_STRING)
|
||||
|
||||
|
||||
def warm_up_cross_encoders() -> None:
|
||||
logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
|
||||
|
||||
|
@@ -10,7 +10,6 @@ from danswer.utils.logger import setup_logger
|
||||
from model_server.custom_models import router as custom_models_router
|
||||
from model_server.custom_models import warm_up_intent_model
|
||||
from model_server.encoders import router as encoders_router
|
||||
from model_server.encoders import warm_up_bi_encoder
|
||||
from model_server.encoders import warm_up_cross_encoders
|
||||
|
||||
|
||||
@@ -33,7 +32,6 @@ def get_model_app() -> FastAPI:
|
||||
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
|
||||
logger.info(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
warm_up_bi_encoder()
|
||||
warm_up_cross_encoders()
|
||||
warm_up_intent_model()
|
||||
|
||||
|
Reference in New Issue
Block a user