From 57f0323f5285e774feebc91f34d54370c8e5db79 Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Mon, 20 Nov 2023 17:28:23 -0800
Subject: [PATCH] NLP Model Warmup Reworked (#748)

---
 backend/danswer/indexing/embedder.py        |  4 +-
 backend/danswer/main.py                     |  4 +-
 backend/danswer/search/search_nlp_models.py | 62 +++++++++++----------
 backend/danswer/search/search_runner.py     | 11 +---
 4 files changed, 38 insertions(+), 43 deletions(-)

diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py
index 8d224aa10..d2b57477c 100644
--- a/backend/danswer/indexing/embedder.py
+++ b/backend/danswer/indexing/embedder.py
@@ -13,7 +13,7 @@ from danswer.utils.timing import log_function_time
 
 
 @log_function_time()
-def encode_chunks(
+def embed_chunks(
     chunks: list[DocAwareChunk],
     embedding_model: SentenceTransformer | None = None,
     batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
@@ -67,4 +67,4 @@ def encode_chunks(
 
 class DefaultEmbedder(Embedder):
     def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
-        return encode_chunks(chunks)
+        return embed_chunks(chunks)
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index cf51b5c28..5d94a5ee2 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -200,14 +200,14 @@ def get_application() -> FastAPI:
         )
     else:
         logger.info("Warming up local NLP models.")
+        warm_up_models()
+
         if torch.cuda.is_available():
             logger.info("GPU is available")
         else:
             logger.info("GPU is not available")
         logger.info(f"Torch Threads: {torch.get_num_threads()}")
 
-        warm_up_models()
-
         # This is for the LLM, most LLMs will not need warming up
         get_default_llm().log_model_configs()
         get_default_qa_model().warm_up_model()
diff --git a/backend/danswer/search/search_nlp_models.py b/backend/danswer/search/search_nlp_models.py
index e6d8a292b..20d86567e 100644
--- a/backend/danswer/search/search_nlp_models.py
+++ b/backend/danswer/search/search_nlp_models.py
@@ -46,7 +46,8 @@ def get_local_embedding_model(
     max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
 ) -> SentenceTransformer:
     global _EMBED_MODEL
-    if _EMBED_MODEL is None:
+    if _EMBED_MODEL is None or max_context_length != _EMBED_MODEL.max_seq_length:
+        logger.info(f"Loading {model_name}")
         _EMBED_MODEL = SentenceTransformer(model_name)
         _EMBED_MODEL.max_seq_length = max_context_length
     return _EMBED_MODEL
@@ -57,10 +58,13 @@ def get_local_reranking_model_ensemble(
     max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
 ) -> list[CrossEncoder]:
     global _RERANK_MODELS
-    if _RERANK_MODELS is None:
-        _RERANK_MODELS = [CrossEncoder(model_name) for model_name in model_names]
-        for model in _RERANK_MODELS:
+    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
+        _RERANK_MODELS = []
+        for model_name in model_names:
+            logger.info(f"Loading {model_name}")
+            model = CrossEncoder(model_name)
             model.max_length = max_context_length
+            _RERANK_MODELS.append(model)
     return _RERANK_MODELS
 
 
@@ -76,7 +80,7 @@ def get_local_intent_model(
     max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
 ) -> TFDistilBertForSequenceClassification:
     global _INTENT_MODEL
-    if _INTENT_MODEL is None:
+    if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
         _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
             model_name
         )
@@ -84,30 +88,6 @@ def get_local_intent_model(
     _INTENT_MODEL.max_seq_length = max_context_length
     return _INTENT_MODEL
 
-
-def warm_up_models(
-    indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING
-) -> None:
-    warm_up_str = "Danswer is amazing"
-    get_default_tokenizer()(warm_up_str)
-
-    get_local_embedding_model().encode(warm_up_str)
-
-    if indexer_only:
-        return
-
-    if not skip_cross_encoders:
-        cross_encoders = get_local_reranking_model_ensemble()
-        [
-            cross_encoder.predict((warm_up_str, warm_up_str))
-            for cross_encoder in cross_encoders
-        ]
-
-    intent_tokenizer = get_intent_model_tokenizer()
-    inputs = intent_tokenizer(
-        warm_up_str, return_tensors="tf", truncation=True, padding=True
-    )
-    get_local_intent_model()(inputs)
-
 class EmbeddingModel:
     def __init__(
         self,
@@ -269,3 +249,27 @@ class IntentModel:
         class_percentages = np.round(probabilities.numpy() * 100, 2)
 
         return list(class_percentages.tolist()[0])
+
+
+def warm_up_models(
+    indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING
+) -> None:
+    warm_up_str = (
+        "Danswer is amazing! Check out our easy deployment guide at "
+        "https://docs.danswer.dev/quickstart"
+    )
+    get_default_tokenizer()(warm_up_str)
+
+    EmbeddingModel().encode(texts=[warm_up_str])
+
+    if indexer_only:
+        return
+
+    if not skip_cross_encoders:
+        CrossEncoderEnsembleModel().predict(query=warm_up_str, passages=[warm_up_str])
+
+    intent_tokenizer = get_intent_model_tokenizer()
+    inputs = intent_tokenizer(
+        warm_up_str, return_tensors="tf", truncation=True, padding=True
+    )
+    get_local_intent_model()(inputs)
diff --git a/backend/danswer/search/search_runner.py b/backend/danswer/search/search_runner.py
index 807eaadb6..92e640c9d 100644
--- a/backend/danswer/search/search_runner.py
+++ b/backend/danswer/search/search_runner.py
@@ -7,7 +7,6 @@ import numpy
 from nltk.corpus import stopwords  # type:ignore
 from nltk.stem import WordNetLemmatizer  # type:ignore
 from nltk.tokenize import word_tokenize  # type:ignore
-from sentence_transformers import SentenceTransformer  # type: ignore
 from sqlalchemy.orm import Session
 
 from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
@@ -17,7 +16,6 @@ from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
 from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
-from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
 from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
 from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
 from danswer.db.feedback import create_query_event
@@ -75,17 +73,10 @@ def query_processing(
 
 def embed_query(
     query: str,
-    embedding_model: SentenceTransformer | None = None,
     prefix: str = ASYM_QUERY_PREFIX,
-    normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
 ) -> list[float]:
-    model = embedding_model or EmbeddingModel()
     prefixed_query = prefix + query
-    query_embedding = model.encode(
-        [prefixed_query], normalize_embeddings=normalize_embeddings
-    )[0]
-
-    return query_embedding
+    return EmbeddingModel().encode([prefixed_query])[0]
 
 
 def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
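
Reviewer note (not part of the patch): below is a minimal sketch of how the reworked warm-up and the staleness checks are expected to behave, assuming a local deployment where `danswer` is importable. The sample query and the 256-token context size are illustrative values, not anything this patch prescribes.

```python
# Hedged usage sketch for this patch -- not part of the diff itself.
from danswer.search.search_nlp_models import (
    EmbeddingModel,
    get_local_embedding_model,
    warm_up_models,
)

# API-server startup (the non model-server branch in main.py): warm the
# tokenizer, bi-encoder, cross-encoder ensemble, and intent model through
# the shared model wrappers.
warm_up_models()

# Indexing workers only need the embedding path, so they return early
# before touching the cross-encoders and the intent model.
warm_up_models(indexer_only=True)

# After warm-up, callers such as embed_query() go through EmbeddingModel
# rather than the local SentenceTransformer getter, so the same code path
# is used regardless of where the model actually runs.
embedding = EmbeddingModel().encode(texts=["What does Danswer do?"])[0]

# The local getters now reload when a different context length is
# requested, instead of returning a cached model with a stale
# max_seq_length -- the condition this patch adds to each getter.
model = get_local_embedding_model(max_context_length=256)
assert model.max_seq_length == 256
```

Moving `warm_up_models` below the class definitions lets it exercise `EmbeddingModel` and `CrossEncoderEnsembleModel` directly, so warm-up hits the same wrappers that serve real requests instead of the local-only getters it previously called.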