Warm up reranker (#2111)

This commit is contained in:
Yuhong Sun
2024-08-11 15:20:51 -07:00
committed by GitHub
parent 7fae66b766
commit 79523f2e0a
4 changed files with 80 additions and 44 deletions

View File

@@ -37,7 +37,7 @@ from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus from danswer.db.models import IndexModelStatus
from danswer.db.swap_index import check_index_swap from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
@@ -384,7 +384,7 @@ def update_loop(
if db_embedding_model.cloud_provider_id is None: if db_embedding_model.cloud_provider_id is None:
logger.debug("Running a first inference to warm up embedding model") logger.debug("Running a first inference to warm up embedding model")
warm_up_encoders( warm_up_bi_encoder(
embedding_model=db_embedding_model, embedding_model=db_embedding_model,
model_server_host=INDEXING_MODEL_SERVER_HOST, model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT, model_server_port=MODEL_SERVER_PORT,

View File

@@ -50,7 +50,7 @@ from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.retrieval.search_runner import download_nltk_data from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens from danswer.server.manage.models import SlackBotTokens
@@ -470,7 +470,7 @@ if __name__ == "__main__":
with Session(get_sqlalchemy_engine()) as db_session: with Session(get_sqlalchemy_engine()) as db_session:
embedding_model = get_current_db_embedding_model(db_session) embedding_model = get_current_db_embedding_model(db_session)
if embedding_model.cloud_provider_id is None: if embedding_model.cloud_provider_id is None:
warm_up_encoders( warm_up_bi_encoder(
embedding_model=embedding_model, embedding_model=embedding_model,
model_server_host=MODEL_SERVER_HOST, model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT, model_server_port=MODEL_SERVER_PORT,

View File

@@ -59,7 +59,8 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.llm_initialization import load_llm_providers from danswer.llm.llm_initialization import load_llm_providers
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.search.models import SavedSearchSettings from danswer.search.models import SavedSearchSettings
from danswer.search.retrieval.search_runner import download_nltk_data from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_settings import get_search_settings from danswer.search.search_settings import get_search_settings
@@ -293,26 +294,27 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.info("Reranking is enabled.") logger.info("Reranking is enabled.")
if not DEFAULT_CROSS_ENCODER_MODEL_NAME: if not DEFAULT_CROSS_ENCODER_MODEL_NAME:
raise ValueError("No reranking model specified.") raise ValueError("No reranking model specified.")
search_settings = SavedSearchSettings(
update_search_settings( rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
SavedSearchSettings( provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE)
rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME, if DEFAULT_CROSS_ENCODER_PROVIDER_TYPE
provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE) else None,
if DEFAULT_CROSS_ENCODER_PROVIDER_TYPE is not None api_key=DEFAULT_CROSS_ENCODER_API_KEY,
else None, disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING,
api_key=DEFAULT_CROSS_ENCODER_API_KEY, num_rerank=NUM_POSTPROCESSED_RESULTS,
disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING, multilingual_expansion=[
num_rerank=NUM_POSTPROCESSED_RESULTS, s.strip()
multilingual_expansion=[ for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
s.strip() if s.strip()
for s in MULTILINGUAL_QUERY_EXPANSION.split(",") ]
if s.strip() if MULTILINGUAL_QUERY_EXPANSION
] else [],
if MULTILINGUAL_QUERY_EXPANSION multipass_indexing=ENABLE_MULTIPASS_INDEXING,
else [],
multipass_indexing=ENABLE_MULTIPASS_INDEXING,
)
) )
update_search_settings(search_settings)
if search_settings.rerank_model_name and not search_settings.provider_type:
warm_up_cross_encoder(search_settings.rerank_model_name)
logger.info("Verifying query preprocessing (NLTK) data is downloaded") logger.info("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data() download_nltk_data()
@@ -336,7 +338,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}") logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
if db_embedding_model.cloud_provider_id is None: if db_embedding_model.cloud_provider_id is None:
warm_up_encoders( warm_up_bi_encoder(
embedding_model=db_embedding_model, embedding_model=db_embedding_model,
model_server_host=MODEL_SERVER_HOST, model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT, model_server_port=MODEL_SERVER_PORT,

View File

@@ -1,5 +1,8 @@
import re import re
import time import time
from collections.abc import Callable
from functools import wraps
from typing import Any
import requests import requests
from httpx import HTTPError from httpx import HTTPError
@@ -31,6 +34,13 @@ from shared_configs.utils import batch_list
logger = setup_logger() logger = setup_logger()
# Sample texts used for the throwaway "warm up" inference calls in this module:
# warm_up_bi_encoder joins them with spaces into a single string to embed, and
# warm_up_cross_encoder passes the first entry plus the list of remaining
# entries as the two arguments to RerankingModel.predict.
WARM_UP_STRINGS = [
"Danswer is amazing!",
"Check out our easy deployment guide at",
"https://docs.danswer.dev/quickstart",
]
def clean_model_name(model_str: str) -> str: def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_") return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
@@ -281,7 +291,31 @@ class QueryAnalysisModel:
return response_model.is_keyword, response_model.keywords return response_model.is_keyword, response_model.keywords
def warm_up_retry(
    func: Callable[..., Any],
    tries: int = 20,
    delay: int = 5,
    *args: Any,
    **kwargs: Any,
) -> Callable[..., Any]:
    """Wrap ``func`` so that calling it is retried until it succeeds.

    The returned wrapper forwards its own positional and keyword arguments to
    ``func``. Any ``Exception`` raised by ``func`` is logged and, as long as
    attempts remain, the wrapper sleeps ``delay`` seconds and tries again.

    Args:
        func: The callable to invoke.
        tries: Maximum number of times ``func`` is called.
        delay: Seconds to sleep between two consecutive attempts.
        *args: Accepted only for backward compatibility; ignored. Arguments
            for ``func`` must be passed to the returned wrapper instead.
        **kwargs: Accepted only for backward compatibility; ignored.

    Returns:
        A callable with the metadata of ``func`` that returns whatever the
        first successful call to ``func`` returns.

    Raises:
        Exception: Raised by the wrapper once every attempt has failed (or
            when ``tries`` is not positive). The message lists all collected
            errors and the last one is attached as ``__cause__``.
    """

    @wraps(func)
    def wrapper(*call_args: Any, **call_kwargs: Any) -> Any:
        exceptions: list[Exception] = []
        for attempt in range(1, tries + 1):
            try:
                return func(*call_args, **call_kwargs)
            except Exception as e:
                exceptions.append(e)
                if attempt >= tries:
                    # Nothing left to retry: do not waste `delay` seconds
                    # sleeping before reporting the failure.
                    logger.exception(f"Attempt {attempt} failed; no retries left.")
                    break
                logger.exception(
                    f"Attempt {attempt} failed; retrying in {delay} seconds..."
                )
                time.sleep(delay)

        # Chain the most recent error so its traceback is not lost.
        last_error = exceptions[-1] if exceptions else None
        raise Exception(f"All retries failed: {exceptions}") from last_error

    return wrapper
def warm_up_bi_encoder(
embedding_model: DBEmbeddingModel, embedding_model: DBEmbeddingModel,
model_server_host: str = MODEL_SERVER_HOST, model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT, model_server_port: int = MODEL_SERVER_PORT,
@@ -289,12 +323,8 @@ def warm_up_encoders(
model_name = embedding_model.model_name model_name = embedding_model.model_name
normalize = embedding_model.normalize normalize = embedding_model.normalize
provider_type = embedding_model.provider_type provider_type = embedding_model.provider_type
warm_up_str = ( warm_up_str = " ".join(WARM_UP_STRINGS)
"Danswer is amazing! Check out our easy deployment guide at "
"https://docs.danswer.dev/quickstart"
)
# May not be the exact same tokenizer used for the indexing flow
logger.debug(f"Warming up encoder model: {model_name}") logger.debug(f"Warming up encoder model: {model_name}")
get_tokenizer(model_name=model_name, provider_type=provider_type).encode( get_tokenizer(model_name=model_name, provider_type=provider_type).encode(
warm_up_str warm_up_str
@@ -312,16 +342,20 @@ def warm_up_encoders(
api_key=None, api_key=None,
) )
# First time downloading the models it may take even longer, but just in case, retry_encode = warm_up_retry(embed_model.encode)
# retry the whole server retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
wait_time = 5
for _ in range(20):
try: def warm_up_cross_encoder(
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY) rerank_model_name: str,
return ) -> None:
except Exception: logger.debug(f"Warming up reranking model: {rerank_model_name}")
logger.exception(
f"Failed to run test embedding, retrying in {wait_time} seconds..." reranking_model = RerankingModel(
) model_name=rerank_model_name,
time.sleep(wait_time) provider_type=None,
raise Exception("Failed to run test embedding.") api_key=None,
)
retry_rerank = warm_up_retry(reranking_model.predict)
retry_rerank(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])