NLP Model Warmup Reworked (#748)

Yuhong Sun 2023-11-20 17:28:23 -08:00 committed by GitHub
parent 6e9f31d1e9
commit 57f0323f52
4 changed files with 38 additions and 43 deletions

File 1 of 4

@@ -13,7 +13,7 @@ from danswer.utils.timing import log_function_time
 @log_function_time()
-def encode_chunks(
+def embed_chunks(
     chunks: list[DocAwareChunk],
     embedding_model: SentenceTransformer | None = None,
     batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
@@ -67,4 +67,4 @@ def encode_chunks(
 class DefaultEmbedder(Embedder):
     def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
-        return encode_chunks(chunks)
+        return embed_chunks(chunks)
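
The rename aligns the module-level helper with the `Embedder.embed` interface it backs; callers are expected to go through `DefaultEmbedder` rather than the function itself. A minimal usage sketch under that assumption (`doc_chunks` is a hypothetical variable, not part of this diff):

    # Hypothetical caller; doc_chunks: list[DocAwareChunk]
    embedder = DefaultEmbedder()
    indexed_chunks: list[IndexChunk] = embedder.embed(doc_chunks)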

File 2 of 4

@@ -200,14 +200,14 @@ def get_application() -> FastAPI:
             )
         else:
             logger.info("Warming up local NLP models.")
-            warm_up_models()
             if torch.cuda.is_available():
                 logger.info("GPU is available")
             else:
                 logger.info("GPU is not available")
             logger.info(f"Torch Threads: {torch.get_num_threads()}")
+            warm_up_models()

             # This is for the LLM, most LLMs will not need warming up
             get_default_llm().log_model_configs()
             get_default_qa_model().warm_up_model()
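
Moving `warm_up_models()` below the device checks means the GPU and thread-count lines are logged before the first slow model load, so the environment is visible in the logs even if warmup stalls. A minimal sketch of the resulting startup order, assuming only that `warm_up_models` is importable from the model module changed below (the helper name here is illustrative):

    import logging

    import torch

    logger = logging.getLogger(__name__)

    def startup_nlp_warmup() -> None:  # hypothetical helper name
        # 1. Log the execution environment first, before any slow work.
        if torch.cuda.is_available():
            logger.info("GPU is available")
        else:
            logger.info("GPU is not available")
        logger.info(f"Torch Threads: {torch.get_num_threads()}")
        # 2. Only then pay the model-loading cost.
        warm_up_models()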

File 3 of 4

@@ -46,7 +46,8 @@ def get_local_embedding_model(
     max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
 ) -> SentenceTransformer:
     global _EMBED_MODEL
-    if _EMBED_MODEL is None:
+    if _EMBED_MODEL is None or max_context_length != _EMBED_MODEL.max_seq_length:
+        logger.info(f"Loading {model_name}")
         _EMBED_MODEL = SentenceTransformer(model_name)
         _EMBED_MODEL.max_seq_length = max_context_length
     return _EMBED_MODEL
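
The extra condition turns the module-level singleton into a cache keyed on `max_seq_length`: requesting a different context length now reloads the model instead of silently returning one configured for the old length. The same lazy-load-and-invalidate pattern in isolation (a minimal sketch; `get_model` and `_CACHED_MODEL` are illustrative names, not from the diff):

    from sentence_transformers import SentenceTransformer  # type: ignore

    _CACHED_MODEL: SentenceTransformer | None = None

    def get_model(model_name: str, max_seq_length: int) -> SentenceTransformer:
        global _CACHED_MODEL
        # Reload when nothing is cached, or when the cached copy was
        # configured with a different sequence length than requested.
        if _CACHED_MODEL is None or _CACHED_MODEL.max_seq_length != max_seq_length:
            _CACHED_MODEL = SentenceTransformer(model_name)
            _CACHED_MODEL.max_seq_length = max_seq_length
        return _CACHED_MODEL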
@@ -57,10 +58,13 @@ def get_local_reranking_model_ensemble(
     max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
 ) -> list[CrossEncoder]:
     global _RERANK_MODELS
-    if _RERANK_MODELS is None:
-        _RERANK_MODELS = [CrossEncoder(model_name) for model_name in model_names]
-        for model in _RERANK_MODELS:
+    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
+        _RERANK_MODELS = []
+        for model_name in model_names:
+            logger.info(f"Loading {model_name}")
+            model = CrossEncoder(model_name)
             model.max_length = max_context_length
+            _RERANK_MODELS.append(model)
     return _RERANK_MODELS
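
Unrolling the list comprehension lets each cross-encoder have its `max_length` set and its load logged before it is appended, and the `_RERANK_MODELS[0].max_length` check applies the same staleness rule as the embedding loader above. For context, each model in the ensemble scores (query, passage) pairs independently; a hedged usage sketch (averaging the ensemble scores is illustrative, not necessarily how Danswer combines them):

    from sentence_transformers import CrossEncoder  # type: ignore

    def rerank_scores(
        query: str, passages: list[str], ensemble: list[CrossEncoder]
    ) -> list[float]:
        # Each model scores all (query, passage) pairs; average per passage.
        per_model = [
            model.predict([(query, p) for p in passages]) for model in ensemble
        ]
        return [float(sum(scores)) / len(per_model) for scores in zip(*per_model)]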
@@ -76,7 +80,7 @@ def get_local_intent_model(
     max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
 ) -> TFDistilBertForSequenceClassification:
     global _INTENT_MODEL
-    if _INTENT_MODEL is None:
+    if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
         _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
             model_name
         )
@@ -84,30 +88,6 @@ def get_local_intent_model(
     return _INTENT_MODEL


-def warm_up_models(
-    indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING
-) -> None:
-    warm_up_str = "Danswer is amazing"
-    get_default_tokenizer()(warm_up_str)
-    get_local_embedding_model().encode(warm_up_str)
-    if indexer_only:
-        return
-    if not skip_cross_encoders:
-        cross_encoders = get_local_reranking_model_ensemble()
-        [
-            cross_encoder.predict((warm_up_str, warm_up_str))
-            for cross_encoder in cross_encoders
-        ]
-    intent_tokenizer = get_intent_model_tokenizer()
-    inputs = intent_tokenizer(
-        warm_up_str, return_tensors="tf", truncation=True, padding=True
-    )
-    get_local_intent_model()(inputs)


 class EmbeddingModel:
     def __init__(
         self,
@@ -269,3 +249,27 @@ class IntentModel:
         class_percentages = np.round(probabilities.numpy() * 100, 2)
         return list(class_percentages.tolist()[0])


+def warm_up_models(
+    indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING
+) -> None:
+    warm_up_str = (
+        "Danswer is amazing! Check out our easy deployment guide at "
+        "https://docs.danswer.dev/quickstart"
+    )
+
+    get_default_tokenizer()(warm_up_str)
+    EmbeddingModel().encode(texts=[warm_up_str])
+
+    if indexer_only:
+        return
+
+    if not skip_cross_encoders:
+        CrossEncoderEnsembleModel().predict(query=warm_up_str, passages=[warm_up_str])
+
+    intent_tokenizer = get_intent_model_tokenizer()
+    inputs = intent_tokenizer(
+        warm_up_str, return_tensors="tf", truncation=True, padding=True
+    )
+    get_local_intent_model()(inputs)
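
The rewritten warmup, moved below the class definitions it now depends on, exercises the same high-level wrappers (`EmbeddingModel`, `CrossEncoderEnsembleModel`) that serve real requests rather than the low-level `get_local_*` loaders, and the longer warm-up string with a URL pushes more realistic input through the tokenizer. The flags read as a split between process types; a sketch of the intended call sites under that assumption:

    # API server process: warm everything the query path touches
    # (cross-encoders included unless SKIP_RERANKING is set).
    warm_up_models()

    # Indexing-only process: the tokenizer and embedding model are enough.
    warm_up_models(indexer_only=True)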

File 4 of 4

@@ -7,7 +7,6 @@ import numpy
 from nltk.corpus import stopwords  # type:ignore
 from nltk.stem import WordNetLemmatizer  # type:ignore
 from nltk.tokenize import word_tokenize  # type:ignore
-from sentence_transformers import SentenceTransformer  # type: ignore
 from sqlalchemy.orm import Session

 from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
@@ -17,7 +16,6 @@ from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
 from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
-from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
 from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
 from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
 from danswer.db.feedback import create_query_event
@@ -75,17 +73,10 @@ def query_processing(
 def embed_query(
     query: str,
-    embedding_model: SentenceTransformer | None = None,
     prefix: str = ASYM_QUERY_PREFIX,
-    normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
 ) -> list[float]:
-    model = embedding_model or EmbeddingModel()
     prefixed_query = prefix + query
-    query_embedding = model.encode(
-        [prefixed_query], normalize_embeddings=normalize_embeddings
-    )[0]
-    return query_embedding
+    return EmbeddingModel().encode([prefixed_query])[0]


 def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
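
With model selection and normalization folded into `EmbeddingModel`, `embed_query` is reduced to asymmetric prefixing: bi-encoders trained with distinct query and passage prefixes expect the same marker at inference time. A hypothetical illustration with an E5-style model, where "query: " and "passage: " are that model family's conventions rather than values taken from this diff:

    from sentence_transformers import SentenceTransformer  # type: ignore

    model = SentenceTransformer("intfloat/e5-base-v2")
    # The query-side prefix must match what the model saw during training.
    q = model.encode(["query: how do I deploy danswer?"], normalize_embeddings=True)
    p = model.encode(
        ["passage: See the deployment guide at https://docs.danswer.dev/quickstart"],
        normalize_embeddings=True,
    )
    similarity = (q @ p.T).item()  # cosine similarity, since both are normalized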