Fix empty / reverted embeddings (#1910)

This commit is contained in:
pablodanswer
2024-07-23 22:41:31 -07:00
committed by GitHub
parent 6ff8e6c0ea
commit 48a0d29a5c
8 changed files with 71 additions and 25 deletions

View File

@@ -80,7 +80,9 @@ class CloudEmbedding:
raise ValueError(f"Unsupported provider: {provider}")
self.client = _initialize_client(api_key, self.provider, model)
def _embed_openai(self, texts: list[str], model: str | None) -> list[list[float]]:
def _embed_openai(
self, texts: list[str], model: str | None
) -> list[list[float] | None]:
if model is None:
model = DEFAULT_OPENAI_MODEL
@@ -92,7 +94,7 @@ class CloudEmbedding:
def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float]]:
) -> list[list[float] | None]:
if model is None:
model = DEFAULT_COHERE_MODEL
@@ -108,7 +110,7 @@ class CloudEmbedding:
def _embed_voyage(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float]]:
) -> list[list[float] | None]:
if model is None:
model = DEFAULT_VOYAGE_MODEL
@@ -124,7 +126,7 @@ class CloudEmbedding:
def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float]]:
) -> list[list[float] | None]:
if model is None:
model = DEFAULT_VERTEX_MODEL
@@ -147,7 +149,7 @@ class CloudEmbedding:
texts: list[str],
text_type: EmbedTextType,
model_name: str | None = None,
) -> list[list[float]]:
) -> list[list[float] | None]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
@@ -235,9 +237,20 @@ def embed_text(
api_key: str | None,
provider_type: str | None,
prefix: str | None,
) -> list[list[float]]:
) -> list[list[float] | None]:
non_empty_texts = []
empty_indices = []
for idx, text in enumerate(texts):
if text.strip():
non_empty_texts.append(text)
else:
empty_indices.append(idx)
# Third party API based embedding model
if provider_type is not None:
if not non_empty_texts:
embeddings = []
elif provider_type is not None:
logger.debug(f"Embedding text with provider: {provider_type}")
if api_key is None:
raise RuntimeError("API key not provided for cloud model")
@@ -254,14 +267,17 @@ def embed_text(
api_key=api_key, provider=provider_type, model=model_name
)
embeddings = cloud_model.embed(
texts=texts,
texts=non_empty_texts,
model_name=model_name,
text_type=text_type,
)
# Locally running model
elif model_name is not None:
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
prefixed_texts = (
[f"{prefix}{text}" for text in non_empty_texts]
if prefix
else non_empty_texts
)
local_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
@@ -277,14 +293,26 @@ def embed_text(
if embeddings is None:
raise RuntimeError("Failed to create Embeddings")
if not isinstance(embeddings, list):
embeddings = embeddings.tolist()
embeddings_with_nulls: list[list[float] | None] = []
current_embedding_index = 0
for idx in range(len(texts)):
if idx in empty_indices:
embeddings_with_nulls.append(None)
else:
embedding = embeddings[current_embedding_index]
if isinstance(embedding, list) or embedding is None:
embeddings_with_nulls.append(embedding)
else:
embeddings_with_nulls.append(embedding.tolist())
current_embedding_index += 1
embeddings = embeddings_with_nulls
return embeddings
@simple_log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float] | None]:
cross_encoders = get_local_reranking_model_ensemble()
sim_scores = [
encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore