NLP Model Warmup Reworked (#748)
parent 6e9f31d1e9
commit 57f0323f52
@@ -13,7 +13,7 @@ from danswer.utils.timing import log_function_time
 
 
 @log_function_time()
-def encode_chunks(
+def embed_chunks(
     chunks: list[DocAwareChunk],
     embedding_model: SentenceTransformer | None = None,
     batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,

@@ -67,4 +67,4 @@ def encode_chunks(
 
 class DefaultEmbedder(Embedder):
     def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
-        return encode_chunks(chunks)
+        return embed_chunks(chunks)
@@ -200,14 +200,14 @@ def get_application() -> FastAPI:
         )
     else:
         logger.info("Warming up local NLP models.")
+        warm_up_models()
+
         if torch.cuda.is_available():
             logger.info("GPU is available")
         else:
             logger.info("GPU is not available")
         logger.info(f"Torch Threads: {torch.get_num_threads()}")
-
-        warm_up_models()
 
         # This is for the LLM, most LLMs will not need warming up
         get_default_llm().log_model_configs()
         get_default_qa_model().warm_up_model()
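Moving `warm_up_models()` to the top of the local-model branch means the heavyweight downloads and first-inference initialization happen during startup, before the server begins answering requests. A rough illustration of the cost this hides from the first caller (the model name here is only an example, not necessarily Danswer's default):

import time

from sentence_transformers import SentenceTransformer  # type: ignore

start = time.monotonic()
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
model.encode(["warm up"])  # first call pays download/load/cache costs
print(f"cold call: {time.monotonic() - start:.1f}s")

start = time.monotonic()
model.encode(["already warm"])  # later calls reuse the loaded model
print(f"warm call: {time.monotonic() - start:.3f}s")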
@@ -46,7 +46,8 @@ def get_local_embedding_model(
     max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
 ) -> SentenceTransformer:
     global _EMBED_MODEL
-    if _EMBED_MODEL is None:
+    if _EMBED_MODEL is None or max_context_length != _EMBED_MODEL.max_seq_length:
+        logger.info(f"Loading {model_name}")
         _EMBED_MODEL = SentenceTransformer(model_name)
         _EMBED_MODEL.max_seq_length = max_context_length
     return _EMBED_MODEL
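The extra clause matters because these accessors cache a module-level singleton: previously, a second caller passing a different `max_context_length` silently got back the model configured with the first caller's value. A toy reproduction of that stale-cache behavior (names hypothetical, not Danswer's code):

_CACHED: dict[str, int] | None = None

def old_get_model(max_len: int) -> dict[str, int]:
    global _CACHED
    if _CACHED is None:  # old guard: checks existence only
        _CACHED = {"max_seq_length": max_len}
    return _CACHED

old_get_model(512)
# The second caller asked for 128 but receives the 512-token config:
assert old_get_model(128)["max_seq_length"] == 512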
@@ -57,10 +58,13 @@ def get_local_reranking_model_ensemble(
     max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
 ) -> list[CrossEncoder]:
     global _RERANK_MODELS
-    if _RERANK_MODELS is None:
-        _RERANK_MODELS = [CrossEncoder(model_name) for model_name in model_names]
-        for model in _RERANK_MODELS:
+    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
+        _RERANK_MODELS = []
+        for model_name in model_names:
+            logger.info(f"Loading {model_name}")
+            model = CrossEncoder(model_name)
             model.max_length = max_context_length
+            _RERANK_MODELS.append(model)
     return _RERANK_MODELS
 
 
@@ -76,7 +80,7 @@ def get_local_intent_model(
     max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
 ) -> TFDistilBertForSequenceClassification:
     global _INTENT_MODEL
-    if _INTENT_MODEL is None:
+    if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
         _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
             model_name
         )
@@ -84,30 +88,6 @@ def get_local_intent_model(
     return _INTENT_MODEL
 
 
-def warm_up_models(
-    indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING
-) -> None:
-    warm_up_str = "Danswer is amazing"
-    get_default_tokenizer()(warm_up_str)
-    get_local_embedding_model().encode(warm_up_str)
-
-    if indexer_only:
-        return
-
-    if not skip_cross_encoders:
-        cross_encoders = get_local_reranking_model_ensemble()
-        [
-            cross_encoder.predict((warm_up_str, warm_up_str))
-            for cross_encoder in cross_encoders
-        ]
-
-    intent_tokenizer = get_intent_model_tokenizer()
-    inputs = intent_tokenizer(
-        warm_up_str, return_tensors="tf", truncation=True, padding=True
-    )
-    get_local_intent_model()(inputs)
-
-
 class EmbeddingModel:
     def __init__(
         self,
@@ -269,3 +249,27 @@ class IntentModel:
         class_percentages = np.round(probabilities.numpy() * 100, 2)
 
         return list(class_percentages.tolist()[0])
+
+
+def warm_up_models(
+    indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING
+) -> None:
+    warm_up_str = (
+        "Danswer is amazing! Check out our easy deployment guide at "
+        "https://docs.danswer.dev/quickstart"
+    )
+    get_default_tokenizer()(warm_up_str)
+
+    EmbeddingModel().encode(texts=[warm_up_str])
+
+    if indexer_only:
+        return
+
+    if not skip_cross_encoders:
+        CrossEncoderEnsembleModel().predict(query=warm_up_str, passages=[warm_up_str])
+
+    intent_tokenizer = get_intent_model_tokenizer()
+    inputs = intent_tokenizer(
+        warm_up_str, return_tensors="tf", truncation=True, padding=True
+    )
+    get_local_intent_model()(inputs)
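The relocated `warm_up_models` now drives the same wrapper classes (`EmbeddingModel`, `CrossEncoderEnsembleModel`) that serve real traffic, so the warm-up path and the inference path can no longer drift apart. Plausible call sites, assuming the flags keep the semantics shown in the signature:

# Full API server: warm everything, including rerankers and the intent model.
warm_up_models()

# Indexing-only process: just the tokenizer and embedding model are exercised.
warm_up_models(indexer_only=True)

# Deployments with reranking disabled skip the cross-encoder pass.
warm_up_models(skip_cross_encoders=True)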
@@ -7,7 +7,6 @@ import numpy
 from nltk.corpus import stopwords  # type:ignore
 from nltk.stem import WordNetLemmatizer  # type:ignore
 from nltk.tokenize import word_tokenize  # type:ignore
-from sentence_transformers import SentenceTransformer  # type: ignore
 from sqlalchemy.orm import Session
 
 from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
@@ -17,7 +16,6 @@ from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
 from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
-from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
 from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
 from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
 from danswer.db.feedback import create_query_event
@@ -75,17 +73,10 @@ def query_processing(
 
 def embed_query(
     query: str,
-    embedding_model: SentenceTransformer | None = None,
     prefix: str = ASYM_QUERY_PREFIX,
-    normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
 ) -> list[float]:
-    model = embedding_model or EmbeddingModel()
     prefixed_query = prefix + query
-    query_embedding = model.encode(
-        [prefixed_query], normalize_embeddings=normalize_embeddings
-    )[0]
-
-    return query_embedding
+    return EmbeddingModel().encode([prefixed_query])[0]
 
 
 def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
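With the `SentenceTransformer` handle and the `normalize_embeddings` flag removed from the signature, `embed_query` reduces to prefixing the query for asymmetric search and delegating to `EmbeddingModel`, which presumably owns model selection and normalization. A sketch of what such a wrapper might look like (the internals are an assumption, not Danswer's actual class, and the model name is only an example):

from sentence_transformers import SentenceTransformer  # type: ignore

class EmbeddingModelSketch:
    """Hypothetical stand-in for Danswer's EmbeddingModel wrapper."""

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2") -> None:
        self._model = SentenceTransformer(model_name)

    def encode(self, texts: list[str]) -> list[list[float]]:
        # Normalization becomes an internal policy instead of a per-caller flag.
        return self._model.encode(texts, normalize_embeddings=True).tolist()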