No Null Embeddings (#1982)

This commit is contained in:
Yuhong Sun
2024-07-30 19:54:49 -07:00
committed by GitHub
parent 60a87d9472
commit 036d5c737e
18 changed files with 132 additions and 146 deletions

View File

@@ -29,6 +29,7 @@ from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
from shared_configs.configs import CROSS_ENCODER_MODEL_ENSEMBLE
from shared_configs.configs import INDEXING_ONLY
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
@@ -80,9 +81,7 @@ class CloudEmbedding:
raise ValueError(f"Unsupported provider: {provider}")
self.client = _initialize_client(api_key, self.provider, model)
def _embed_openai(
self, texts: list[str], model: str | None
) -> list[list[float] | None]:
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
if model is None:
model = DEFAULT_OPENAI_MODEL
@@ -104,7 +103,7 @@ class CloudEmbedding:
def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float] | None]:
) -> list[Embedding]:
if model is None:
model = DEFAULT_COHERE_MODEL
@@ -120,7 +119,7 @@ class CloudEmbedding:
def _embed_voyage(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float] | None]:
) -> list[Embedding]:
if model is None:
model = DEFAULT_VOYAGE_MODEL
@@ -136,7 +135,7 @@ class CloudEmbedding:
def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float] | None]:
) -> list[Embedding]:
if model is None:
model = DEFAULT_VERTEX_MODEL
@@ -159,7 +158,7 @@ class CloudEmbedding:
texts: list[str],
text_type: EmbedTextType,
model_name: str | None = None,
) -> list[list[float] | None]:
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
@@ -247,19 +246,13 @@ def embed_text(
api_key: str | None,
provider_type: str | None,
prefix: str | None,
) -> list[list[float] | None]:
non_empty_texts = []
empty_indices = []
for idx, text in enumerate(texts):
if text.strip():
non_empty_texts.append(text)
else:
empty_indices.append(idx)
) -> list[Embedding]:
if not all(texts):
raise ValueError("Empty strings are not allowed for embedding.")
# Third party API based embedding model
if not non_empty_texts:
embeddings = []
if not texts:
raise ValueError("No texts provided for embedding.")
elif provider_type is not None:
logger.debug(f"Embedding text with provider: {provider_type}")
if api_key is None:
@@ -277,47 +270,36 @@ def embed_text(
api_key=api_key, provider=provider_type, model=model_name
)
embeddings = cloud_model.embed(
texts=non_empty_texts,
texts=texts,
model_name=model_name,
text_type=text_type,
)
# Check for None values in embeddings
if any(embedding is None for embedding in embeddings):
error_message = "Embeddings contain None values\n"
error_message += "Corresponding texts:\n"
error_message += "\n".join(texts)
raise ValueError(error_message)
elif model_name is not None:
prefixed_texts = (
[f"{prefix}{text}" for text in non_empty_texts]
if prefix
else non_empty_texts
)
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
local_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings = local_model.encode(
embeddings_vectors = local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
)
embeddings = [
embedding if isinstance(embedding, list) else embedding.tolist()
for embedding in embeddings_vectors
]
else:
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)
if embeddings is None:
raise RuntimeError("Failed to create Embeddings")
embeddings_with_nulls: list[list[float] | None] = []
current_embedding_index = 0
for idx in range(len(texts)):
if idx in empty_indices:
embeddings_with_nulls.append(None)
else:
embedding = embeddings[current_embedding_index]
if isinstance(embedding, list) or embedding is None:
embeddings_with_nulls.append(embedding)
else:
embeddings_with_nulls.append(embedding.tolist())
current_embedding_index += 1
embeddings = embeddings_with_nulls
return embeddings
@@ -337,6 +319,8 @@ async def process_embed_request(
) -> EmbedResponse:
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
elif not all(embed_request.texts):
raise ValueError("Empty strings are not allowed for embedding.")
try:
if embed_request.text_type == EmbedTextType.QUERY:
@@ -371,8 +355,10 @@ async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse
if not embed_request.documents or not embed_request.query:
raise HTTPException(
status_code=400, detail="No documents or query to be reranked"
status_code=400, detail="Missing documents or query for reranking"
)
if not all(embed_request.documents):
raise ValueError("Empty documents cannot be reranked.")
try:
sim_scores = calc_sim_scores(