Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-06-01 02:30:18 +02:00
NLP Model Warmup Reworked (#748)

commit 57f0323f52 (parent 6e9f31d1e9)
@@ -13,7 +13,7 @@ from danswer.utils.timing import log_function_time


 @log_function_time()
-def encode_chunks(
+def embed_chunks(
     chunks: list[DocAwareChunk],
     embedding_model: SentenceTransformer | None = None,
     batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
@@ -67,4 +67,4 @@ def encode_chunks(

 class DefaultEmbedder(Embedder):
     def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
-        return encode_chunks(chunks)
+        return embed_chunks(chunks)
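For orientation, the renamed helper sits behind an Embedder interface; below is a minimal self-contained sketch of that shape, with stand-in chunk types and a stubbed embed_chunks in place of Danswer's real implementations:

    from abc import ABC, abstractmethod


    class DocAwareChunk:  # stand-in for danswer's chunk model
        ...


    class IndexChunk:  # stand-in for the embedded-chunk model
        ...


    def embed_chunks(chunks: list[DocAwareChunk]) -> list[IndexChunk]:
        # Stub standing in for the renamed module-level helper.
        return [IndexChunk() for _ in chunks]


    class Embedder(ABC):
        @abstractmethod
        def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
            raise NotImplementedError


    class DefaultEmbedder(Embedder):
        def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
            return embed_chunks(chunks)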
@@ -200,14 +200,14 @@ def get_application() -> FastAPI:
         )
     else:
         logger.info("Warming up local NLP models.")
-        warm_up_models()

         if torch.cuda.is_available():
             logger.info("GPU is available")
         else:
             logger.info("GPU is not available")
         logger.info(f"Torch Threads: {torch.get_num_threads()}")

+        warm_up_models()
         # This is for the LLM, most LLMs will not need warming up
         get_default_llm().log_model_configs()
         get_default_qa_model().warm_up_model()
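The reorder means hardware facts are logged before the slow model loads begin, so a hung-looking startup is explainable from the logs. A minimal sketch of the same ordering in a bare FastAPI app, with a stand-in warm_up function in place of warm_up_models:

    import logging

    import torch
    from fastapi import FastAPI

    logger = logging.getLogger(__name__)
    app = FastAPI()


    def warm_up() -> None:
        # Stand-in for danswer's warm_up_models(); loads and primes local models.
        ...


    @app.on_event("startup")
    def startup() -> None:
        # Log GPU and thread info first, then kick off the slow warmup.
        if torch.cuda.is_available():
            logger.info("GPU is available")
        else:
            logger.info("GPU is not available")
        logger.info(f"Torch Threads: {torch.get_num_threads()}")
        warm_up()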
@@ -46,7 +46,8 @@ def get_local_embedding_model(
     max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
 ) -> SentenceTransformer:
     global _EMBED_MODEL
-    if _EMBED_MODEL is None:
+    if _EMBED_MODEL is None or max_context_length != _EMBED_MODEL.max_seq_length:
+        logger.info(f"Loading {model_name}")
         _EMBED_MODEL = SentenceTransformer(model_name)
         _EMBED_MODEL.max_seq_length = max_context_length
     return _EMBED_MODEL
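The guard change turns a load-once singleton into a cache that also invalidates when the requested context length no longer matches the cached model's. The same pattern in isolation, with an illustrative default model name:

    from sentence_transformers import SentenceTransformer

    _MODEL: SentenceTransformer | None = None


    def get_model(
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",  # illustrative
        max_context_length: int = 512,
    ) -> SentenceTransformer:
        """Lazily load one shared model; reload if the context length changed."""
        global _MODEL
        # Rebuild when nothing is cached, or when the cached model was configured
        # with a different sequence length than the caller now wants.
        if _MODEL is None or _MODEL.max_seq_length != max_context_length:
            _MODEL = SentenceTransformer(model_name)
            _MODEL.max_seq_length = max_context_length
        return _MODEL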
@@ -57,10 +58,13 @@ def get_local_reranking_model_ensemble(
     max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
 ) -> list[CrossEncoder]:
     global _RERANK_MODELS
-    if _RERANK_MODELS is None:
-        _RERANK_MODELS = [CrossEncoder(model_name) for model_name in model_names]
-        for model in _RERANK_MODELS:
-            model.max_length = max_context_length
+    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
+        _RERANK_MODELS = []
+        for model_name in model_names:
+            logger.info(f"Loading {model_name}")
+            model = CrossEncoder(model_name)
+            model.max_length = max_context_length
+            _RERANK_MODELS.append(model)
     return _RERANK_MODELS


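Each cross-encoder in the rebuilt ensemble scores (query, passage) pairs, and the per-model scores can then be combined. A rough sketch of scoring with such an ensemble, averaging across models (illustrative model names, not necessarily Danswer's configured ones):

    from sentence_transformers import CrossEncoder

    MODEL_NAMES = [  # illustrative
        "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "cross-encoder/ms-marco-TinyBERT-L-2-v2",
    ]
    _models = [CrossEncoder(name) for name in MODEL_NAMES]


    def ensemble_scores(query: str, passages: list[str]) -> list[float]:
        """Mean relevance score per passage across all cross-encoders."""
        pairs = [(query, passage) for passage in passages]
        per_model = [model.predict(pairs) for model in _models]
        # zip(*per_model) groups the scores each model gave to the same passage.
        return [float(sum(scores)) / len(per_model) for scores in zip(*per_model)]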
@@ -76,7 +80,7 @@ def get_local_intent_model(
     max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
 ) -> TFDistilBertForSequenceClassification:
     global _INTENT_MODEL
-    if _INTENT_MODEL is None:
+    if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
         _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
             model_name
         )
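The intent classifier gets the same lazy-load-with-invalidation guard. End to end, inference looks roughly like the sketch below (an illustrative DistilBERT checkpoint; Danswer's actual model name comes from its configs), matching the class_percentages math visible as context in a later hunk:

    import numpy as np
    import tensorflow as tf
    from transformers import AutoTokenizer, TFDistilBertForSequenceClassification

    MODEL_NAME = "distilbert-base-uncased"  # illustrative checkpoint

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = TFDistilBertForSequenceClassification.from_pretrained(MODEL_NAME)

    inputs = tokenizer(
        "how do I deploy danswer?", return_tensors="tf", truncation=True, padding=True
    )
    logits = model(inputs).logits
    probabilities = tf.nn.softmax(logits, axis=-1)
    class_percentages = np.round(probabilities.numpy() * 100, 2)
    print(list(class_percentages.tolist()[0]))  # per-class confidence, in percent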
@@ -84,30 +88,6 @@ def get_local_intent_model(
     return _INTENT_MODEL


-def warm_up_models(
-    indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING
-) -> None:
-    warm_up_str = "Danswer is amazing"
-    get_default_tokenizer()(warm_up_str)
-    get_local_embedding_model().encode(warm_up_str)
-
-    if indexer_only:
-        return
-
-    if not skip_cross_encoders:
-        cross_encoders = get_local_reranking_model_ensemble()
-        [
-            cross_encoder.predict((warm_up_str, warm_up_str))
-            for cross_encoder in cross_encoders
-        ]
-
-    intent_tokenizer = get_intent_model_tokenizer()
-    inputs = intent_tokenizer(
-        warm_up_str, return_tensors="tf", truncation=True, padding=True
-    )
-    get_local_intent_model()(inputs)
-
-
 class EmbeddingModel:
     def __init__(
         self,
@@ -269,3 +249,27 @@ class IntentModel:
         class_percentages = np.round(probabilities.numpy() * 100, 2)

         return list(class_percentages.tolist()[0])
+
+
+def warm_up_models(
+    indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING
+) -> None:
+    warm_up_str = (
+        "Danswer is amazing! Check out our easy deployment guide at "
+        "https://docs.danswer.dev/quickstart"
+    )
+    get_default_tokenizer()(warm_up_str)
+
+    EmbeddingModel().encode(texts=[warm_up_str])
+
+    if indexer_only:
+        return
+
+    if not skip_cross_encoders:
+        CrossEncoderEnsembleModel().predict(query=warm_up_str, passages=[warm_up_str])
+
+    intent_tokenizer = get_intent_model_tokenizer()
+    inputs = intent_tokenizer(
+        warm_up_str, return_tensors="tf", truncation=True, padding=True
+    )
+    get_local_intent_model()(inputs)
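Usage-wise, the relocated warm_up_models now exercises the model-server wrapper classes (EmbeddingModel, CrossEncoderEnsembleModel) rather than the local singletons, and the indexer_only flag lets a pure indexing process warm only the tokenizer and embedding model. A hypothetical worker entry point (the import path is assumed, not confirmed by the diff):

    from danswer.search.search_nlp_models import warm_up_models  # assumed path


    def run_indexing_worker() -> None:
        # The early return inside warm_up_models skips the cross-encoders and
        # the intent model, which only query serving needs.
        warm_up_models(indexer_only=True)
        ...  # proceed to pull documents and embed chunks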
@@ -7,7 +7,6 @@ import numpy
 from nltk.corpus import stopwords  # type:ignore
 from nltk.stem import WordNetLemmatizer  # type:ignore
 from nltk.tokenize import word_tokenize  # type:ignore
-from sentence_transformers import SentenceTransformer  # type: ignore
 from sqlalchemy.orm import Session

 from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
@@ -17,7 +16,6 @@ from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
 from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
-from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
 from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
 from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
 from danswer.db.feedback import create_query_event
@@ -75,17 +73,10 @@ def query_processing(

 def embed_query(
     query: str,
-    embedding_model: SentenceTransformer | None = None,
     prefix: str = ASYM_QUERY_PREFIX,
-    normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
 ) -> list[float]:
-    model = embedding_model or EmbeddingModel()
     prefixed_query = prefix + query
-    query_embedding = model.encode(
-        [prefixed_query], normalize_embeddings=normalize_embeddings
-    )[0]
-
-    return query_embedding
+    return EmbeddingModel().encode([prefixed_query])[0]


 def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
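The slimmed-down embed_query leans on the asymmetric-embedding convention: ASYM_QUERY_PREFIX is prepended so queries are encoded in the query-side space the model expects, distinct from how passages are encoded at indexing time. The same idea directly against SentenceTransformer, with an illustrative prefix and model:

    from sentence_transformers import SentenceTransformer

    QUERY_PREFIX = "query: "  # illustrative; the real value is ASYM_QUERY_PREFIX

    _model = SentenceTransformer("intfloat/e5-base-v2")  # assumed asymmetric model


    def embed_query(query: str, prefix: str = QUERY_PREFIX) -> list[float]:
        # Asymmetric models expect queries (vs. passages) to carry a prefix.
        prefixed_query = prefix + query
        return _model.encode([prefixed_query])[0].tolist()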