mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-20 13:05:49 +02:00
Warm up reranker (#2111)
This commit is contained in:
@@ -37,7 +37,7 @@ from danswer.db.models import IndexAttempt
|
|||||||
from danswer.db.models import IndexingStatus
|
from danswer.db.models import IndexingStatus
|
||||||
from danswer.db.models import IndexModelStatus
|
from danswer.db.models import IndexModelStatus
|
||||||
from danswer.db.swap_index import check_index_swap
|
from danswer.db.swap_index import check_index_swap
|
||||||
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
|
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
from danswer.utils.variable_functionality import global_version
|
from danswer.utils.variable_functionality import global_version
|
||||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||||
@@ -384,7 +384,7 @@ def update_loop(
|
|||||||
|
|
||||||
if db_embedding_model.cloud_provider_id is None:
|
if db_embedding_model.cloud_provider_id is None:
|
||||||
logger.debug("Running a first inference to warm up embedding model")
|
logger.debug("Running a first inference to warm up embedding model")
|
||||||
warm_up_encoders(
|
warm_up_bi_encoder(
|
||||||
embedding_model=db_embedding_model,
|
embedding_model=db_embedding_model,
|
||||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||||
model_server_port=MODEL_SERVER_PORT,
|
model_server_port=MODEL_SERVER_PORT,
|
||||||
|
@@ -50,7 +50,7 @@ from danswer.danswerbot.slack.utils import respond_in_thread
|
|||||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||||
from danswer.db.engine import get_sqlalchemy_engine
|
from danswer.db.engine import get_sqlalchemy_engine
|
||||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||||
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
|
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||||
from danswer.one_shot_answer.models import ThreadMessage
|
from danswer.one_shot_answer.models import ThreadMessage
|
||||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||||
from danswer.server.manage.models import SlackBotTokens
|
from danswer.server.manage.models import SlackBotTokens
|
||||||
@@ -470,7 +470,7 @@ if __name__ == "__main__":
|
|||||||
with Session(get_sqlalchemy_engine()) as db_session:
|
with Session(get_sqlalchemy_engine()) as db_session:
|
||||||
embedding_model = get_current_db_embedding_model(db_session)
|
embedding_model = get_current_db_embedding_model(db_session)
|
||||||
if embedding_model.cloud_provider_id is None:
|
if embedding_model.cloud_provider_id is None:
|
||||||
warm_up_encoders(
|
warm_up_bi_encoder(
|
||||||
embedding_model=embedding_model,
|
embedding_model=embedding_model,
|
||||||
model_server_host=MODEL_SERVER_HOST,
|
model_server_host=MODEL_SERVER_HOST,
|
||||||
model_server_port=MODEL_SERVER_PORT,
|
model_server_port=MODEL_SERVER_PORT,
|
||||||
|
@@ -59,7 +59,8 @@ from danswer.document_index.interfaces import DocumentIndex
|
|||||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||||
from danswer.llm.llm_initialization import load_llm_providers
|
from danswer.llm.llm_initialization import load_llm_providers
|
||||||
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
|
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||||
|
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
|
||||||
from danswer.search.models import SavedSearchSettings
|
from danswer.search.models import SavedSearchSettings
|
||||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||||
from danswer.search.search_settings import get_search_settings
|
from danswer.search.search_settings import get_search_settings
|
||||||
@@ -293,26 +294,27 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
|||||||
logger.info("Reranking is enabled.")
|
logger.info("Reranking is enabled.")
|
||||||
if not DEFAULT_CROSS_ENCODER_MODEL_NAME:
|
if not DEFAULT_CROSS_ENCODER_MODEL_NAME:
|
||||||
raise ValueError("No reranking model specified.")
|
raise ValueError("No reranking model specified.")
|
||||||
|
search_settings = SavedSearchSettings(
|
||||||
update_search_settings(
|
rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
|
||||||
SavedSearchSettings(
|
provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE)
|
||||||
rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
|
if DEFAULT_CROSS_ENCODER_PROVIDER_TYPE
|
||||||
provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE)
|
else None,
|
||||||
if DEFAULT_CROSS_ENCODER_PROVIDER_TYPE is not None
|
api_key=DEFAULT_CROSS_ENCODER_API_KEY,
|
||||||
else None,
|
disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING,
|
||||||
api_key=DEFAULT_CROSS_ENCODER_API_KEY,
|
num_rerank=NUM_POSTPROCESSED_RESULTS,
|
||||||
disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING,
|
multilingual_expansion=[
|
||||||
num_rerank=NUM_POSTPROCESSED_RESULTS,
|
s.strip()
|
||||||
multilingual_expansion=[
|
for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
|
||||||
s.strip()
|
if s.strip()
|
||||||
for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
|
]
|
||||||
if s.strip()
|
if MULTILINGUAL_QUERY_EXPANSION
|
||||||
]
|
else [],
|
||||||
if MULTILINGUAL_QUERY_EXPANSION
|
multipass_indexing=ENABLE_MULTIPASS_INDEXING,
|
||||||
else [],
|
|
||||||
multipass_indexing=ENABLE_MULTIPASS_INDEXING,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
update_search_settings(search_settings)
|
||||||
|
|
||||||
|
if search_settings.rerank_model_name and not search_settings.provider_type:
|
||||||
|
warm_up_cross_encoder(search_settings.rerank_model_name)
|
||||||
|
|
||||||
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
|
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
|
||||||
download_nltk_data()
|
download_nltk_data()
|
||||||
@@ -336,7 +338,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
|||||||
|
|
||||||
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
|
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
|
||||||
if db_embedding_model.cloud_provider_id is None:
|
if db_embedding_model.cloud_provider_id is None:
|
||||||
warm_up_encoders(
|
warm_up_bi_encoder(
|
||||||
embedding_model=db_embedding_model,
|
embedding_model=db_embedding_model,
|
||||||
model_server_host=MODEL_SERVER_HOST,
|
model_server_host=MODEL_SERVER_HOST,
|
||||||
model_server_port=MODEL_SERVER_PORT,
|
model_server_port=MODEL_SERVER_PORT,
|
||||||
|
@@ -1,5 +1,8 @@
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from httpx import HTTPError
|
from httpx import HTTPError
|
||||||
@@ -31,6 +34,13 @@ from shared_configs.utils import batch_list
|
|||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
|
WARM_UP_STRINGS = [
|
||||||
|
"Danswer is amazing!",
|
||||||
|
"Check out our easy deployment guide at",
|
||||||
|
"https://docs.danswer.dev/quickstart",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def clean_model_name(model_str: str) -> str:
|
def clean_model_name(model_str: str) -> str:
|
||||||
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
|
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
|
||||||
|
|
||||||
@@ -281,7 +291,31 @@ class QueryAnalysisModel:
|
|||||||
return response_model.is_keyword, response_model.keywords
|
return response_model.is_keyword, response_model.keywords
|
||||||
|
|
||||||
|
|
||||||
def warm_up_encoders(
|
def warm_up_retry(
|
||||||
|
func: Callable[..., Any],
|
||||||
|
tries: int = 20,
|
||||||
|
delay: int = 5,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Callable[..., Any]:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
exceptions = []
|
||||||
|
for attempt in range(tries):
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
exceptions.append(e)
|
||||||
|
logger.exception(
|
||||||
|
f"Attempt {attempt + 1} failed; retrying in {delay} seconds..."
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
raise Exception(f"All retries failed: {exceptions}")
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def warm_up_bi_encoder(
|
||||||
embedding_model: DBEmbeddingModel,
|
embedding_model: DBEmbeddingModel,
|
||||||
model_server_host: str = MODEL_SERVER_HOST,
|
model_server_host: str = MODEL_SERVER_HOST,
|
||||||
model_server_port: int = MODEL_SERVER_PORT,
|
model_server_port: int = MODEL_SERVER_PORT,
|
||||||
@@ -289,12 +323,8 @@ def warm_up_encoders(
|
|||||||
model_name = embedding_model.model_name
|
model_name = embedding_model.model_name
|
||||||
normalize = embedding_model.normalize
|
normalize = embedding_model.normalize
|
||||||
provider_type = embedding_model.provider_type
|
provider_type = embedding_model.provider_type
|
||||||
warm_up_str = (
|
warm_up_str = " ".join(WARM_UP_STRINGS)
|
||||||
"Danswer is amazing! Check out our easy deployment guide at "
|
|
||||||
"https://docs.danswer.dev/quickstart"
|
|
||||||
)
|
|
||||||
|
|
||||||
# May not be the exact same tokenizer used for the indexing flow
|
|
||||||
logger.debug(f"Warming up encoder model: {model_name}")
|
logger.debug(f"Warming up encoder model: {model_name}")
|
||||||
get_tokenizer(model_name=model_name, provider_type=provider_type).encode(
|
get_tokenizer(model_name=model_name, provider_type=provider_type).encode(
|
||||||
warm_up_str
|
warm_up_str
|
||||||
@@ -312,16 +342,20 @@ def warm_up_encoders(
|
|||||||
api_key=None,
|
api_key=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# First time downloading the models it may take even longer, but just in case,
|
retry_encode = warm_up_retry(embed_model.encode)
|
||||||
# retry the whole server
|
retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
|
||||||
wait_time = 5
|
|
||||||
for _ in range(20):
|
|
||||||
try:
|
def warm_up_cross_encoder(
|
||||||
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
|
rerank_model_name: str,
|
||||||
return
|
) -> None:
|
||||||
except Exception:
|
logger.debug(f"Warming up reranking model: {rerank_model_name}")
|
||||||
logger.exception(
|
|
||||||
f"Failed to run test embedding, retrying in {wait_time} seconds..."
|
reranking_model = RerankingModel(
|
||||||
)
|
model_name=rerank_model_name,
|
||||||
time.sleep(wait_time)
|
provider_type=None,
|
||||||
raise Exception("Failed to run test embedding.")
|
api_key=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
retry_rerank = warm_up_retry(reranking_model.predict)
|
||||||
|
retry_rerank(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])
|
||||||
|
Reference in New Issue
Block a user