DAN-51 Model warm up on start (#27)

Also added a minor prompt update
This commit is contained in:
Yuhong Sun 2023-05-10 21:01:14 -07:00 committed by GitHub
parent 632a643b7a
commit 38bcb3ee6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 17 additions and 1 deletions

View File

@ -10,7 +10,8 @@ SYSTEM_ROLE = "You are a Question Answering system that answers queries based on
BASE_PROMPT = ( BASE_PROMPT = (
f"Answer the query based on provided documents and quote relevant sections. " f"Answer the query based on provided documents and quote relevant sections. "
f"Respond with a json containing a concise answer and up to three most relevant quotes from the documents.\n" f"Respond with a json containing a concise answer and up to three most relevant quotes from the documents. "
f"The quotes must be EXACT substrings from the documents.\n"
) )
UNABLE_TO_FIND_JSON_MSG = ( UNABLE_TO_FIND_JSON_MSG = (

View File

@ -86,6 +86,15 @@ def get_application() -> FastAPI:
RequestValidationError, validation_exception_handler RequestValidationError, validation_exception_handler
) )
@application.on_event("startup")
async def startup_event() -> None:
from danswer.semantic_search.semantic_search import (
warm_up_models,
) # To avoid circular imports
warm_up_models()
logger.info("Semantic Search models are ready.")
return application return application

View File

@ -66,6 +66,11 @@ def get_default_reranking_model() -> CrossEncoder:
return _RERANK_MODEL return _RERANK_MODEL
def warm_up_models() -> None:
get_default_embedding_model().encode("Danswer is so cool")
get_default_reranking_model().predict(("What is Danswer", "Enterprise QA")) # type: ignore
@log_function_time() @log_function_time()
def semantic_reranking( def semantic_reranking(
query: str, query: str,

View File

@ -26,5 +26,6 @@ def shared_precompare_cleanup(text: str) -> str:
text = text.replace(".", "") text = text.replace(".", "")
text = text.replace(":", "") text = text.replace(":", "")
text = text.replace(",", "") text = text.replace(",", "")
text = text.replace("-", "")
return text return text