mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-03 08:20:40 +02:00
Warm Up Models Prep (#2196)
This commit is contained in:
parent
048cb8dd55
commit
bb1916d5d0
@ -10,8 +10,8 @@ from alembic import op
|
|||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision = "4b08d97e175a"
|
revision = "4b08d97e175a"
|
||||||
down_revision = "d9ec13955951"
|
down_revision = "d9ec13955951"
|
||||||
branch_labels = None
|
branch_labels: None = None
|
||||||
depends_on = None
|
depends_on: None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
|
@ -109,8 +109,6 @@ class EmbeddingModelDetail(BaseModel):
|
|||||||
embedding_model: "EmbeddingModel",
|
embedding_model: "EmbeddingModel",
|
||||||
) -> "EmbeddingModelDetail":
|
) -> "EmbeddingModelDetail":
|
||||||
return cls(
|
return cls(
|
||||||
# When constructing EmbeddingModel Detail for user-facing flows, strip the
|
|
||||||
# unneeded additional data after the `_`s
|
|
||||||
model_name=embedding_model.model_name,
|
model_name=embedding_model.model_name,
|
||||||
model_dim=embedding_model.model_dim,
|
model_dim=embedding_model.model_dim,
|
||||||
normalize=embedding_model.normalize,
|
normalize=embedding_model.normalize,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
@ -304,6 +305,7 @@ def warm_up_bi_encoder(
|
|||||||
embedding_model: DBEmbeddingModel,
|
embedding_model: DBEmbeddingModel,
|
||||||
model_server_host: str = MODEL_SERVER_HOST,
|
model_server_host: str = MODEL_SERVER_HOST,
|
||||||
model_server_port: int = MODEL_SERVER_PORT,
|
model_server_port: int = MODEL_SERVER_PORT,
|
||||||
|
non_blocking: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
model_name = embedding_model.model_name
|
model_name = embedding_model.model_name
|
||||||
normalize = embedding_model.normalize
|
normalize = embedding_model.normalize
|
||||||
@ -327,12 +329,26 @@ def warm_up_bi_encoder(
|
|||||||
api_key=None,
|
api_key=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
retry_encode = warm_up_retry(embed_model.encode)
|
def _warm_up() -> None:
|
||||||
retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
|
try:
|
||||||
|
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
|
||||||
|
logger.debug(f"Warm-up complete for encoder model: {model_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Warm-up request failed for encoder model {model_name}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if non_blocking:
|
||||||
|
threading.Thread(target=_warm_up, daemon=True).start()
|
||||||
|
logger.debug(f"Started non-blocking warm-up for encoder model: {model_name}")
|
||||||
|
else:
|
||||||
|
retry_encode = warm_up_retry(embed_model.encode)
|
||||||
|
retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
|
||||||
|
|
||||||
|
|
||||||
def warm_up_cross_encoder(
|
def warm_up_cross_encoder(
|
||||||
rerank_model_name: str,
|
rerank_model_name: str,
|
||||||
|
non_blocking: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug(f"Warming up reranking model: {rerank_model_name}")
|
logger.debug(f"Warming up reranking model: {rerank_model_name}")
|
||||||
|
|
||||||
@ -342,5 +358,20 @@ def warm_up_cross_encoder(
|
|||||||
api_key=None,
|
api_key=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
retry_rerank = warm_up_retry(reranking_model.predict)
|
def _warm_up() -> None:
|
||||||
retry_rerank(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])
|
try:
|
||||||
|
reranking_model.predict(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])
|
||||||
|
logger.debug(f"Warm-up complete for reranking model: {rerank_model_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Warm-up request failed for reranking model {rerank_model_name}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if non_blocking:
|
||||||
|
threading.Thread(target=_warm_up, daemon=True).start()
|
||||||
|
logger.debug(
|
||||||
|
f"Started non-blocking warm-up for reranking model: {rerank_model_name}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
retry_rerank = warm_up_retry(reranking_model.predict)
|
||||||
|
retry_rerank(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])
|
||||||
|
@ -311,11 +311,11 @@ export default function EmbeddingForm() {
|
|||||||
>
|
>
|
||||||
<>
|
<>
|
||||||
<div className="text-lg">
|
<div className="text-lg">
|
||||||
{selectedProvider.model_name} is a low-performance model.
|
{selectedProvider.model_name} is a lower accuracy model.
|
||||||
<br />
|
<br />
|
||||||
We recommend the following alternatives.
|
We recommend the following alternatives.
|
||||||
<li>OpenAI for cloud-based</li>
|
<li>Cohere embed-english-v3.0 for cloud-based</li>
|
||||||
<li>Nomic for self-hosted</li>
|
<li>Nomic nomic-embed-text-v1 for self-hosted</li>
|
||||||
</div>
|
</div>
|
||||||
<div className="flex mt-4 justify-between">
|
<div className="flex mt-4 justify-between">
|
||||||
<Button color="green" onClick={() => setShowPoorModel(false)}>
|
<Button color="green" onClick={() => setShowPoorModel(false)}>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user