Pull out stripping of model suffix (#2175)

This commit is contained in:
pablodanswer 2024-08-20 11:32:03 -07:00 committed by GitHub
parent 12f0dbcfc5
commit 71c2b16a01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -250,6 +250,11 @@ def embed_text(
if not all(texts): if not all(texts):
raise ValueError("Empty strings are not allowed for embedding.") raise ValueError("Empty strings are not allowed for embedding.")
# strip additional metadata from model name right before constructing embedding requests
stripped_model_name = (
model_name.removesuffix(ALT_INDEX_SUFFIX) if model_name else None
)
# Third party API based embedding model # Third party API based embedding model
if not texts: if not texts:
raise ValueError("No texts provided for embedding.") raise ValueError("No texts provided for embedding.")
@ -267,11 +272,11 @@ def embed_text(
) )
cloud_model = CloudEmbedding( cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=model_name api_key=api_key, provider=provider_type, model=stripped_model_name
) )
embeddings = cloud_model.embed( embeddings = cloud_model.embed(
texts=texts, texts=texts,
model_name=model_name, model_name=stripped_model_name,
text_type=text_type, text_type=text_type,
) )
@ -282,11 +287,9 @@ def embed_text(
error_message += "\n".join(texts) error_message += "\n".join(texts)
raise ValueError(error_message) raise ValueError(error_message)
elif model_name is not None: elif stripped_model_name is not None:
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
# strip additional metadata from model name right before constructing from Huggingface
stripped_model_name = model_name.removesuffix(ALT_INDEX_SUFFIX)
local_model = get_embedding_model( local_model = get_embedding_model(
model_name=stripped_model_name, max_context_length=max_context_length model_name=stripped_model_name, max_context_length=max_context_length
) )