Warm up reranker (#2111)
commit 79523f2e0a
parent 7fae66b766
@@ -37,7 +37,7 @@ from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
 from danswer.db.models import IndexModelStatus
 from danswer.db.swap_index import check_index_swap
-from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
+from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import global_version
 from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
@@ -384,7 +384,7 @@ def update_loop(
 
     if db_embedding_model.cloud_provider_id is None:
         logger.debug("Running a first inference to warm up embedding model")
-        warm_up_encoders(
+        warm_up_bi_encoder(
            embedding_model=db_embedding_model,
            model_server_host=INDEXING_MODEL_SERVER_HOST,
            model_server_port=MODEL_SERVER_PORT,
@@ -50,7 +50,7 @@ from danswer.danswerbot.slack.utils import respond_in_thread
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.dynamic_configs.interface import ConfigNotFoundError
-from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
+from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
 from danswer.one_shot_answer.models import ThreadMessage
 from danswer.search.retrieval.search_runner import download_nltk_data
 from danswer.server.manage.models import SlackBotTokens
@@ -470,7 +470,7 @@ if __name__ == "__main__":
     with Session(get_sqlalchemy_engine()) as db_session:
         embedding_model = get_current_db_embedding_model(db_session)
         if embedding_model.cloud_provider_id is None:
-            warm_up_encoders(
+            warm_up_bi_encoder(
                embedding_model=embedding_model,
                model_server_host=MODEL_SERVER_HOST,
                model_server_port=MODEL_SERVER_PORT,
@@ -59,7 +59,8 @@ from danswer.document_index.interfaces import DocumentIndex
 from danswer.dynamic_configs.factory import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.llm.llm_initialization import load_llm_providers
-from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
+from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
+from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
 from danswer.search.models import SavedSearchSettings
 from danswer.search.retrieval.search_runner import download_nltk_data
 from danswer.search.search_settings import get_search_settings
@@ -293,26 +294,27 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
         logger.info("Reranking is enabled.")
         if not DEFAULT_CROSS_ENCODER_MODEL_NAME:
             raise ValueError("No reranking model specified.")
 
-        update_search_settings(
-            SavedSearchSettings(
-                rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
-                provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE)
-                if DEFAULT_CROSS_ENCODER_PROVIDER_TYPE is not None
-                else None,
-                api_key=DEFAULT_CROSS_ENCODER_API_KEY,
-                disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING,
-                num_rerank=NUM_POSTPROCESSED_RESULTS,
-                multilingual_expansion=[
-                    s.strip()
-                    for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
-                    if s.strip()
-                ]
-                if MULTILINGUAL_QUERY_EXPANSION
-                else [],
-                multipass_indexing=ENABLE_MULTIPASS_INDEXING,
-            )
-        )
+        search_settings = SavedSearchSettings(
+            rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
+            provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE)
+            if DEFAULT_CROSS_ENCODER_PROVIDER_TYPE
+            else None,
+            api_key=DEFAULT_CROSS_ENCODER_API_KEY,
+            disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING,
+            num_rerank=NUM_POSTPROCESSED_RESULTS,
+            multilingual_expansion=[
+                s.strip()
+                for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
+                if s.strip()
+            ]
+            if MULTILINGUAL_QUERY_EXPANSION
+            else [],
+            multipass_indexing=ENABLE_MULTIPASS_INDEXING,
+        )
+        update_search_settings(search_settings)
+        if search_settings.rerank_model_name and not search_settings.provider_type:
+            warm_up_cross_encoder(search_settings.rerank_model_name)
 
     logger.info("Verifying query preprocessing (NLTK) data is downloaded")
     download_nltk_data()
@@ -336,7 +338,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
 
     logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
     if db_embedding_model.cloud_provider_id is None:
-        warm_up_encoders(
+        warm_up_bi_encoder(
            embedding_model=db_embedding_model,
            model_server_host=MODEL_SERVER_HOST,
            model_server_port=MODEL_SERVER_PORT,
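
Taken together, the lifespan changes above gate each warm-up on whether the model is served locally: the bi-encoder is warmed up only when the embedding model has no cloud_provider_id, and the cross-encoder only when a rerank model is configured without a hosted provider_type. A condensed, illustrative sketch of that decision flow, reusing the names from the hunks above (not the literal application code):

    # Condensed sketch of the startup warm-up logic; see the hunks above.
    if db_embedding_model.cloud_provider_id is None:
        # Embedding model lives on the local model server, so send it a test request.
        warm_up_bi_encoder(
            embedding_model=db_embedding_model,
            model_server_host=MODEL_SERVER_HOST,
            model_server_port=MODEL_SERVER_PORT,
        )

    if search_settings.rerank_model_name and not search_settings.provider_type:
        # Rerank model is local (no hosted provider), so warm it up as well.
        warm_up_cross_encoder(search_settings.rerank_model_name)
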
@@ -1,5 +1,8 @@
 import re
+import time
+from collections.abc import Callable
+from functools import wraps
 from typing import Any
 
 import requests
 from httpx import HTTPError
@@ -31,6 +34,13 @@ from shared_configs.utils import batch_list
 logger = setup_logger()
 
 
+WARM_UP_STRINGS = [
+    "Danswer is amazing!",
+    "Check out our easy deployment guide at",
+    "https://docs.danswer.dev/quickstart",
+]
+
+
 def clean_model_name(model_str: str) -> str:
     return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
 
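
The new WARM_UP_STRINGS constant is consumed in two shapes later in this file: joined into a single query for the bi-encoder, and split into a query plus candidate passages for the cross-encoder. A small illustrative snippet of those two shapes (the variable names here exist only for this example):

    WARM_UP_STRINGS = [
        "Danswer is amazing!",
        "Check out our easy deployment guide at",
        "https://docs.danswer.dev/quickstart",
    ]

    # Bi-encoder warm-up: one concatenated query string.
    bi_encoder_input = " ".join(WARM_UP_STRINGS)

    # Cross-encoder warm-up: first string as the query, the rest as passages to rerank.
    query, passages = WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:]
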
@@ -281,7 +291,31 @@ class QueryAnalysisModel:
         return response_model.is_keyword, response_model.keywords
 
 
-def warm_up_encoders(
+def warm_up_retry(
+    func: Callable[..., Any],
+    tries: int = 20,
+    delay: int = 5,
+    *args: Any,
+    **kwargs: Any,
+) -> Callable[..., Any]:
+    @wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        exceptions = []
+        for attempt in range(tries):
+            try:
+                return func(*args, **kwargs)
+            except Exception as e:
+                exceptions.append(e)
+                logger.exception(
+                    f"Attempt {attempt + 1} failed; retrying in {delay} seconds..."
+                )
+                time.sleep(delay)
+        raise Exception(f"All retries failed: {exceptions}")
+
+    return wrapper
+
+
+def warm_up_bi_encoder(
     embedding_model: DBEmbeddingModel,
     model_server_host: str = MODEL_SERVER_HOST,
     model_server_port: int = MODEL_SERVER_PORT,
@@ -289,12 +323,8 @@ def warm_up_encoders(
     model_name = embedding_model.model_name
     normalize = embedding_model.normalize
     provider_type = embedding_model.provider_type
-    warm_up_str = (
-        "Danswer is amazing! Check out our easy deployment guide at "
-        "https://docs.danswer.dev/quickstart"
-    )
-
+    warm_up_str = " ".join(WARM_UP_STRINGS)
     # May not be the exact same tokenizer used for the indexing flow
     logger.debug(f"Warming up encoder model: {model_name}")
     get_tokenizer(model_name=model_name, provider_type=provider_type).encode(
         warm_up_str
@@ -312,16 +342,20 @@ def warm_up_encoders(
         api_key=None,
     )
 
-    # First time downloading the models it may take even longer, but just in case,
-    # retry the whole server
-    wait_time = 5
-    for _ in range(20):
-        try:
-            embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
-            return
-        except Exception:
-            logger.exception(
-                f"Failed to run test embedding, retrying in {wait_time} seconds..."
-            )
-            time.sleep(wait_time)
-    raise Exception("Failed to run test embedding.")
+    retry_encode = warm_up_retry(embed_model.encode)
+    retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
+
+
+def warm_up_cross_encoder(
+    rerank_model_name: str,
+) -> None:
+    logger.debug(f"Warming up reranking model: {rerank_model_name}")
+
+    reranking_model = RerankingModel(
+        model_name=rerank_model_name,
+        provider_type=None,
+        api_key=None,
+    )
+
+    retry_rerank = warm_up_retry(reranking_model.predict)
+    retry_rerank(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])
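
As a usage note, warm_up_retry is a wrapper applied at call time rather than a decorator applied at definition time: it takes any callable and returns a version that retries (20 attempts with a 5-second delay by default) before raising. A hypothetical sketch, assuming warm_up_retry is importable from the same module as warm_up_bi_encoder (it is defined alongside it above) and using a placeholder function:

    # Hypothetical usage example; ping_model_server is a placeholder, not part of the diff.
    from danswer.natural_language_processing.search_nlp_models import warm_up_retry

    def ping_model_server(text: str) -> None:
        # Stand-in for a real warm-up call such as embed_model.encode.
        ...

    retried_ping = warm_up_retry(ping_model_server, tries=3, delay=1)
    retried_ping("Danswer is amazing!")  # up to 3 attempts, 1 s apart; raises only if all fail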