Pull out stripping of model suffix (#2175)

This commit is contained in:
pablodanswer 2024-08-20 11:32:03 -07:00 committed by GitHub
parent 12f0dbcfc5
commit 71c2b16a01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -250,6 +250,11 @@ def embed_text(
if not all(texts):
raise ValueError("Empty strings are not allowed for embedding.")
+# strip additional metadata from model name right before constructing embedding requests
+stripped_model_name = (
+model_name.removesuffix(ALT_INDEX_SUFFIX) if model_name else None
+)
# Third party API based embedding model
if not texts:
raise ValueError("No texts provided for embedding.")
@@ -267,11 +272,11 @@ def embed_text(
)
cloud_model = CloudEmbedding(
-api_key=api_key, provider=provider_type, model=model_name
+api_key=api_key, provider=provider_type, model=stripped_model_name
)
embeddings = cloud_model.embed(
texts=texts,
-model_name=model_name,
+model_name=stripped_model_name,
text_type=text_type,
)
@@ -282,11 +287,9 @@ def embed_text(
error_message += "\n".join(texts)
raise ValueError(error_message)
-elif model_name is not None:
+elif stripped_model_name is not None:
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
-# strip additional metadata from model name right before constructing from Huggingface
-stripped_model_name = model_name.removesuffix(ALT_INDEX_SUFFIX)
local_model = get_embedding_model(
model_name=stripped_model_name, max_context_length=max_context_length
)