Warm up reranker (#2111)

Yuhong Sun 2024-08-11 15:20:51 -07:00 committed by GitHub
parent 7fae66b766
commit 79523f2e0a
4 changed files with 80 additions and 44 deletions

View File

@@ -37,7 +37,7 @@ from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
 from danswer.db.models import IndexModelStatus
 from danswer.db.swap_index import check_index_swap
-from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
+from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import global_version
 from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
@@ -384,7 +384,7 @@ def update_loop(
         if db_embedding_model.cloud_provider_id is None:
             logger.debug("Running a first inference to warm up embedding model")
-            warm_up_encoders(
+            warm_up_bi_encoder(
                 embedding_model=db_embedding_model,
                 model_server_host=INDEXING_MODEL_SERVER_HOST,
                 model_server_port=MODEL_SERVER_PORT,
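Both warm-up call sites gate on `cloud_provider_id`: only self-hosted embedding models need a first inference. A minimal sketch of that gating, with a hypothetical stub type and example model names standing in for the real `DBEmbeddingModel`:

```python
# Sketch only: `StubEmbeddingModel` is a made-up stand-in for the
# DBEmbeddingModel row that the real call sites pass in.
from dataclasses import dataclass


@dataclass
class StubEmbeddingModel:
    model_name: str
    cloud_provider_id: int | None = None


def maybe_warm_up_bi_encoder(model: StubEmbeddingModel) -> None:
    # Cloud-hosted embedding APIs load nothing locally, so warm-up is
    # only useful for self-hosted models (cloud_provider_id is None).
    if model.cloud_provider_id is None:
        print(f"Running a first inference to warm up {model.model_name}")


maybe_warm_up_bi_encoder(StubEmbeddingModel("intfloat/e5-base-v2"))  # warms up
maybe_warm_up_bi_encoder(StubEmbeddingModel("hosted-model", cloud_provider_id=1))  # skipped
```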

View File

@@ -50,7 +50,7 @@ from danswer.danswerbot.slack.utils import respond_in_thread
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.dynamic_configs.interface import ConfigNotFoundError
-from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
+from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
 from danswer.one_shot_answer.models import ThreadMessage
 from danswer.search.retrieval.search_runner import download_nltk_data
 from danswer.server.manage.models import SlackBotTokens
@@ -470,7 +470,7 @@ if __name__ == "__main__":
     with Session(get_sqlalchemy_engine()) as db_session:
         embedding_model = get_current_db_embedding_model(db_session)
         if embedding_model.cloud_provider_id is None:
-            warm_up_encoders(
+            warm_up_bi_encoder(
                 embedding_model=embedding_model,
                 model_server_host=MODEL_SERVER_HOST,
                 model_server_port=MODEL_SERVER_PORT,

View File

@@ -59,7 +59,8 @@ from danswer.document_index.interfaces import DocumentIndex
 from danswer.dynamic_configs.factory import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.llm.llm_initialization import load_llm_providers
-from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
+from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
+from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
 from danswer.search.models import SavedSearchSettings
 from danswer.search.retrieval.search_runner import download_nltk_data
 from danswer.search.search_settings import get_search_settings
@@ -293,26 +294,27 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
         logger.info("Reranking is enabled.")
         if not DEFAULT_CROSS_ENCODER_MODEL_NAME:
             raise ValueError("No reranking model specified.")
-        update_search_settings(
-            SavedSearchSettings(
-                rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
-                provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE)
-                if DEFAULT_CROSS_ENCODER_PROVIDER_TYPE is not None
-                else None,
-                api_key=DEFAULT_CROSS_ENCODER_API_KEY,
-                disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING,
-                num_rerank=NUM_POSTPROCESSED_RESULTS,
-                multilingual_expansion=[
-                    s.strip()
-                    for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
-                    if s.strip()
-                ]
-                if MULTILINGUAL_QUERY_EXPANSION
-                else [],
-                multipass_indexing=ENABLE_MULTIPASS_INDEXING,
-            )
-        )
+        search_settings = SavedSearchSettings(
+            rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
+            provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE)
+            if DEFAULT_CROSS_ENCODER_PROVIDER_TYPE
+            else None,
+            api_key=DEFAULT_CROSS_ENCODER_API_KEY,
+            disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING,
+            num_rerank=NUM_POSTPROCESSED_RESULTS,
+            multilingual_expansion=[
+                s.strip()
+                for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
+                if s.strip()
+            ]
+            if MULTILINGUAL_QUERY_EXPANSION
+            else [],
+            multipass_indexing=ENABLE_MULTIPASS_INDEXING,
+        )
+        update_search_settings(search_settings)
+        if search_settings.rerank_model_name and not search_settings.provider_type:
+            warm_up_cross_encoder(search_settings.rerank_model_name)
 
     logger.info("Verifying query preprocessing (NLTK) data is downloaded")
     download_nltk_data()
@@ -336,7 +338,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
     logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
     if db_embedding_model.cloud_provider_id is None:
-        warm_up_encoders(
+        warm_up_bi_encoder(
             embedding_model=db_embedding_model,
             model_server_host=MODEL_SERVER_HOST,
             model_server_port=MODEL_SERVER_PORT,
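The lifespan change binds the settings to a name instead of constructing them inline, so the follow-up warm-up check can read `rerank_model_name` and `provider_type` back off the same object. A rough, self-contained sketch of the resulting startup ordering, with stub types and an example model name standing in for Danswer's:

```python
# Sketch under assumptions: the stub type and persist() stand in for
# SavedSearchSettings and update_search_settings; the model name is an
# arbitrary example.
from dataclasses import dataclass


@dataclass
class StubSearchSettings:
    rerank_model_name: str | None = None
    provider_type: str | None = None  # set when a hosted rerank API is used


def persist(settings: StubSearchSettings) -> None:
    print(f"Persisting search settings: {settings}")


def startup(settings: StubSearchSettings) -> None:
    # 1. Persist the settings first, so a warm-up failure cannot leave
    #    the configured reranker unsaved.
    persist(settings)
    # 2. Warm up only when the reranker is self-hosted: a provider-backed
    #    model is served remotely, so there is nothing to load locally.
    if settings.rerank_model_name and not settings.provider_type:
        print(f"Warming up reranking model: {settings.rerank_model_name}")


startup(StubSearchSettings(rerank_model_name="example/rerank-model"))
```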

View File

@@ -1,5 +1,8 @@
 import re
 import time
+from collections.abc import Callable
+from functools import wraps
+from typing import Any
 
 import requests
 from httpx import HTTPError
@@ -31,6 +34,13 @@ from shared_configs.utils import batch_list
 logger = setup_logger()
 
+WARM_UP_STRINGS = [
+    "Danswer is amazing!",
+    "Check out our easy deployment guide at",
+    "https://docs.danswer.dev/quickstart",
+]
+
 
 def clean_model_name(model_str: str) -> str:
     return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
@@ -281,7 +291,31 @@ class QueryAnalysisModel:
         return response_model.is_keyword, response_model.keywords
 
 
-def warm_up_encoders(
+def warm_up_retry(
+    func: Callable[..., Any],
+    tries: int = 20,
+    delay: int = 5,
+    *args: Any,
+    **kwargs: Any,
+) -> Callable[..., Any]:
+    @wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        exceptions = []
+        for attempt in range(tries):
+            try:
+                return func(*args, **kwargs)
+            except Exception as e:
+                exceptions.append(e)
+                logger.exception(
+                    f"Attempt {attempt + 1} failed; retrying in {delay} seconds..."
+                )
+                time.sleep(delay)
+        raise Exception(f"All retries failed: {exceptions}")
+
+    return wrapper
+
+
+def warm_up_bi_encoder(
     embedding_model: DBEmbeddingModel,
     model_server_host: str = MODEL_SERVER_HOST,
    model_server_port: int = MODEL_SERVER_PORT,
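A usage sketch for the new `warm_up_retry` helper defined above (imported from `danswer.natural_language_processing.search_nlp_models` per this diff); `flaky_encode` is a made-up stand-in for `embed_model.encode`:

```python
import random

from danswer.natural_language_processing.search_nlp_models import warm_up_retry


def flaky_encode(texts: list[str]) -> list[list[float]]:
    # Simulate a model server that is still downloading/loading weights.
    if random.random() < 0.5:
        raise ConnectionError("model server not ready")
    return [[0.0, 0.0, 0.0] for _ in texts]


# Each failed call is logged and retried after `delay` seconds; the wrapper
# raises only after all `tries` attempts have failed.
retry_encode = warm_up_retry(flaky_encode, tries=3, delay=1)
print(retry_encode(texts=["Danswer is amazing!"]))
```

Centralizing the loop in a wrapper lets the same retry policy cover both the embedding and reranking warm-ups instead of duplicating the hand-rolled loop below.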
@@ -289,12 +323,8 @@ def warm_up_encoders(
     model_name = embedding_model.model_name
     normalize = embedding_model.normalize
     provider_type = embedding_model.provider_type
-    warm_up_str = (
-        "Danswer is amazing! Check out our easy deployment guide at "
-        "https://docs.danswer.dev/quickstart"
-    )
-
+    warm_up_str = " ".join(WARM_UP_STRINGS)
 
     # May not be the exact same tokenizer used for the indexing flow
     logger.debug(f"Warming up encoder model: {model_name}")
     get_tokenizer(model_name=model_name, provider_type=provider_type).encode(
         warm_up_str
@@ -312,16 +342,20 @@ def warm_up_encoders(
         api_key=None,
     )
 
     # First time downloading the models it may take even longer, but just in case,
     # retry the whole server
-    wait_time = 5
-    for _ in range(20):
-        try:
-            embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
-            return
-        except Exception:
-            logger.exception(
-                f"Failed to run test embedding, retrying in {wait_time} seconds..."
-            )
-            time.sleep(wait_time)
-    raise Exception("Failed to run test embedding.")
+    retry_encode = warm_up_retry(embed_model.encode)
+    retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
+
+
+def warm_up_cross_encoder(
+    rerank_model_name: str,
+) -> None:
+    logger.debug(f"Warming up reranking model: {rerank_model_name}")
+    reranking_model = RerankingModel(
+        model_name=rerank_model_name,
+        provider_type=None,
+        api_key=None,
+    )
+    retry_rerank = warm_up_retry(reranking_model.predict)
+    retry_rerank(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])
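The last call splits `WARM_UP_STRINGS` because a cross-encoder scores (query, passage) pairs rather than single texts: the first string plays the query and the rest play the passages. A rough sketch of that shape, assuming `RerankingModel` wraps a sentence-transformers cross-encoder; the backend and model name here are assumptions, not Danswer's actual code:

```python
from sentence_transformers import CrossEncoder  # assumed backend

WARM_UP_STRINGS = [
    "Danswer is amazing!",
    "Check out our easy deployment guide at",
    "https://docs.danswer.dev/quickstart",
]

query, passages = WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:]
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # example model
# One relevance score per (query, passage) pair; running this once forces
# the model download/load, which is the whole point of the warm-up.
scores = model.predict([(query, passage) for passage in passages])
print(scores)
```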