Warm Up Models Prep (#2196)

This commit is contained in:
Yuhong Sun 2024-08-20 20:53:02 -07:00 committed by GitHub
parent 048cb8dd55
commit bb1916d5d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 40 additions and 11 deletions

View File

@ -10,8 +10,8 @@ from alembic import op
# revision identifiers, used by Alembic.
revision = "4b08d97e175a"
down_revision = "d9ec13955951"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -109,8 +109,6 @@ class EmbeddingModelDetail(BaseModel):
embedding_model: "EmbeddingModel",
) -> "EmbeddingModelDetail":
return cls(
# When constructing EmbeddingModel Detail for user-facing flows, strip the
# unneeded additional data after the `_`s
model_name=embedding_model.model_name,
model_dim=embedding_model.model_dim,
normalize=embedding_model.normalize,

View File

@ -1,4 +1,5 @@
import re
import threading
import time
from collections.abc import Callable
from functools import wraps
@ -304,6 +305,7 @@ def warm_up_bi_encoder(
embedding_model: DBEmbeddingModel,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
non_blocking: bool = False,
) -> None:
model_name = embedding_model.model_name
normalize = embedding_model.normalize
@ -327,12 +329,26 @@ def warm_up_bi_encoder(
api_key=None,
)
retry_encode = warm_up_retry(embed_model.encode)
retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
def _warm_up() -> None:
try:
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
logger.debug(f"Warm-up complete for encoder model: {model_name}")
except Exception as e:
logger.warning(
f"Warm-up request failed for encoder model {model_name}: {e}"
)
if non_blocking:
threading.Thread(target=_warm_up, daemon=True).start()
logger.debug(f"Started non-blocking warm-up for encoder model: {model_name}")
else:
retry_encode = warm_up_retry(embed_model.encode)
retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
def warm_up_cross_encoder(
rerank_model_name: str,
non_blocking: bool = False,
) -> None:
logger.debug(f"Warming up reranking model: {rerank_model_name}")
@ -342,5 +358,20 @@ def warm_up_cross_encoder(
api_key=None,
)
retry_rerank = warm_up_retry(reranking_model.predict)
retry_rerank(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])
def _warm_up() -> None:
try:
reranking_model.predict(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])
logger.debug(f"Warm-up complete for reranking model: {rerank_model_name}")
except Exception as e:
logger.warning(
f"Warm-up request failed for reranking model {rerank_model_name}: {e}"
)
if non_blocking:
threading.Thread(target=_warm_up, daemon=True).start()
logger.debug(
f"Started non-blocking warm-up for reranking model: {rerank_model_name}"
)
else:
retry_rerank = warm_up_retry(reranking_model.predict)
retry_rerank(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])

View File

@ -311,11 +311,11 @@ export default function EmbeddingForm() {
>
<>
<div className="text-lg">
{selectedProvider.model_name} is a low-performance model.
{selectedProvider.model_name} is a lower accuracy model.
<br />
We recommend the following alternatives.
<li>OpenAI for cloud-based</li>
<li>Nomic for self-hosted</li>
<li>Cohere embed-english-v3.0 for cloud-based</li>
<li>Nomic nomic-embed-text-v1 for self-hosted</li>
</div>
<div className="flex mt-4 justify-between">
<Button color="green" onClick={() => setShowPoorModel(false)}>