Simplify index and model name swap logic (#2188)

This commit is contained in:
pablodanswer
2024-08-20 17:31:00 -07:00
committed by GitHub
parent ec6e2369a1
commit 53387ab3eb
5 changed files with 48 additions and 16 deletions

View File

@ -23,7 +23,6 @@ from model_server.constants import DEFAULT_VOYAGE_MODEL
from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider
from model_server.utils import simple_log_function_time
from shared_configs.configs import ALT_INDEX_SUFFIX
from shared_configs.configs import INDEXING_ONLY
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
@ -250,11 +249,6 @@ def embed_text(
if not all(texts):
raise ValueError("Empty strings are not allowed for embedding.")
# strip additional metadata from model name right before constructing embedding requests
stripped_model_name = (
model_name.removesuffix(ALT_INDEX_SUFFIX) if model_name else None
)
# Third party API based embedding model
if not texts:
raise ValueError("No texts provided for embedding.")
@ -272,11 +266,11 @@ def embed_text(
)
cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=stripped_model_name
api_key=api_key, provider=provider_type, model=model_name
)
embeddings = cloud_model.embed(
texts=texts,
model_name=stripped_model_name,
model_name=model_name,
text_type=text_type,
)
@ -287,11 +281,11 @@ def embed_text(
error_message += "\n".join(texts)
raise ValueError(error_message)
elif stripped_model_name is not None:
elif model_name is not None:
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
local_model = get_embedding_model(
model_name=stripped_model_name, max_context_length=max_context_length
model_name=model_name, max_context_length=max_context_length
)
embeddings_vectors = local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings