diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 2649be0fd1..f192c196be 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -37,7 +37,7 @@ from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus from danswer.db.models import IndexModelStatus from danswer.db.swap_index import check_index_swap -from danswer.natural_language_processing.search_nlp_models import warm_up_encoders +from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import global_version from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable @@ -384,7 +384,7 @@ def update_loop( if db_embedding_model.cloud_provider_id is None: logger.debug("Running a first inference to warm up embedding model") - warm_up_encoders( + warm_up_bi_encoder( embedding_model=db_embedding_model, model_server_host=INDEXING_MODEL_SERVER_HOST, model_server_port=MODEL_SERVER_PORT, diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index 4025d7a6a3..af080d0234 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -50,7 +50,7 @@ from danswer.danswerbot.slack.utils import respond_in_thread from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_sqlalchemy_engine from danswer.dynamic_configs.interface import ConfigNotFoundError -from danswer.natural_language_processing.search_nlp_models import warm_up_encoders +from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder from danswer.one_shot_answer.models import ThreadMessage from danswer.search.retrieval.search_runner import download_nltk_data from danswer.server.manage.models import SlackBotTokens @@ -470,7 +470,7 @@ if __name__ == "__main__": with Session(get_sqlalchemy_engine()) as db_session: embedding_model = get_current_db_embedding_model(db_session) if embedding_model.cloud_provider_id is None: - warm_up_encoders( + warm_up_bi_encoder( embedding_model=embedding_model, model_server_host=MODEL_SERVER_HOST, model_server_port=MODEL_SERVER_PORT, diff --git a/backend/danswer/main.py b/backend/danswer/main.py index feec6a4a0c..549798adcd 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -59,7 +59,8 @@ from danswer.document_index.interfaces import DocumentIndex from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.llm.llm_initialization import load_llm_providers -from danswer.natural_language_processing.search_nlp_models import warm_up_encoders +from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder +from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder from danswer.search.models import SavedSearchSettings from danswer.search.retrieval.search_runner import download_nltk_data from danswer.search.search_settings import get_search_settings @@ -293,26 +294,27 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: logger.info("Reranking is enabled.") if not DEFAULT_CROSS_ENCODER_MODEL_NAME: raise ValueError("No reranking model specified.") - - update_search_settings( - SavedSearchSettings( - rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME, - provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE) - if DEFAULT_CROSS_ENCODER_PROVIDER_TYPE is not None - else None, - api_key=DEFAULT_CROSS_ENCODER_API_KEY, - disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING, - num_rerank=NUM_POSTPROCESSED_RESULTS, - multilingual_expansion=[ - s.strip() - for s in MULTILINGUAL_QUERY_EXPANSION.split(",") - if s.strip() - ] - if MULTILINGUAL_QUERY_EXPANSION - else [], - multipass_indexing=ENABLE_MULTIPASS_INDEXING, - ) + search_settings = SavedSearchSettings( + rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME, + provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE) + if DEFAULT_CROSS_ENCODER_PROVIDER_TYPE + else None, + api_key=DEFAULT_CROSS_ENCODER_API_KEY, + disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING, + num_rerank=NUM_POSTPROCESSED_RESULTS, + multilingual_expansion=[ + s.strip() + for s in MULTILINGUAL_QUERY_EXPANSION.split(",") + if s.strip() + ] + if MULTILINGUAL_QUERY_EXPANSION + else [], + multipass_indexing=ENABLE_MULTIPASS_INDEXING, ) + update_search_settings(search_settings) + + if search_settings.rerank_model_name and not search_settings.provider_type: + warm_up_cross_encoder(search_settings.rerank_model_name) logger.info("Verifying query preprocessing (NLTK) data is downloaded") download_nltk_data() @@ -336,7 +338,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}") if db_embedding_model.cloud_provider_id is None: - warm_up_encoders( + warm_up_bi_encoder( embedding_model=db_embedding_model, model_server_host=MODEL_SERVER_HOST, model_server_port=MODEL_SERVER_PORT, diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py index 82526dd422..c34d57cb81 100644 --- a/backend/danswer/natural_language_processing/search_nlp_models.py +++ b/backend/danswer/natural_language_processing/search_nlp_models.py @@ -1,5 +1,8 @@ import re import time +from collections.abc import Callable +from functools import wraps +from typing import Any import requests from httpx import HTTPError @@ -31,6 +34,13 @@ from shared_configs.utils import batch_list logger = setup_logger() +WARM_UP_STRINGS = [ + "Danswer is amazing!", + "Check out our easy deployment guide at", + "https://docs.danswer.dev/quickstart", +] + + def clean_model_name(model_str: str) -> str: return model_str.replace("/", "_").replace("-", "_").replace(".", "_") @@ -281,7 +291,31 @@ class QueryAnalysisModel: return response_model.is_keyword, response_model.keywords -def warm_up_encoders( +def warm_up_retry( + func: Callable[..., Any], + tries: int = 20, + delay: int = 5, + *args: Any, + **kwargs: Any, +) -> Callable[..., Any]: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + exceptions = [] + for attempt in range(tries): + try: + return func(*args, **kwargs) + except Exception as e: + exceptions.append(e) + logger.exception( + f"Attempt {attempt + 1} failed; retrying in {delay} seconds..." + ) + time.sleep(delay) + raise Exception(f"All retries failed: {exceptions}") + + return wrapper + + +def warm_up_bi_encoder( embedding_model: DBEmbeddingModel, model_server_host: str = MODEL_SERVER_HOST, model_server_port: int = MODEL_SERVER_PORT, @@ -289,12 +323,8 @@ def warm_up_encoders( model_name = embedding_model.model_name normalize = embedding_model.normalize provider_type = embedding_model.provider_type - warm_up_str = ( - "Danswer is amazing! Check out our easy deployment guide at " - "https://docs.danswer.dev/quickstart" - ) + warm_up_str = " ".join(WARM_UP_STRINGS) - # May not be the exact same tokenizer used for the indexing flow logger.debug(f"Warming up encoder model: {model_name}") get_tokenizer(model_name=model_name, provider_type=provider_type).encode( warm_up_str @@ -312,16 +342,20 @@ def warm_up_encoders( api_key=None, ) - # First time downloading the models it may take even longer, but just in case, - # retry the whole server - wait_time = 5 - for _ in range(20): - try: - embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY) - return - except Exception: - logger.exception( - f"Failed to run test embedding, retrying in {wait_time} seconds..." - ) - time.sleep(wait_time) - raise Exception("Failed to run test embedding.") + retry_encode = warm_up_retry(embed_model.encode) + retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY) + + +def warm_up_cross_encoder( + rerank_model_name: str, +) -> None: + logger.debug(f"Warming up reranking model: {rerank_model_name}") + + reranking_model = RerankingModel( + model_name=rerank_model_name, + provider_type=None, + api_key=None, + ) + + retry_rerank = warm_up_retry(reranking_model.predict) + retry_rerank(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])