This commit is contained in:
Yuhong Sun
2024-07-21 10:27:57 -07:00
parent eb3e7610fc
commit 44820b4909
3 changed files with 35 additions and 66 deletions

View File

@@ -1,4 +1,3 @@
import concurrent.futures
import gc
import json
from typing import Any
@@ -81,69 +80,70 @@ class CloudEmbedding:
raise ValueError(f"Unsupported provider: {provider}")
self.client = _initialize_client(api_key, self.provider, model)
def _embed_openai(self, text: str, model: str | None) -> list[float]:
def _embed_openai(self, texts: list[str], model: str | None) -> list[list[float]]:
if model is None:
model = DEFAULT_OPENAI_MODEL
response = self.client.embeddings.create(input=text, model=model)
return response.data[0].embedding
response = self.client.embeddings.create(input=texts, model=model)
return [embedding.embedding for embedding in response.data]
def _embed_cohere(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float]]:
if model is None:
model = DEFAULT_COHERE_MODEL
response = self.client.embed(
texts=[text],
texts=texts,
model=model,
input_type=embedding_type,
)
return response.embeddings[0]
return response.embeddings
def _embed_voyage(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float]]:
if model is None:
model = DEFAULT_VOYAGE_MODEL
response = self.client.embed(text, model=model, input_type=embedding_type)
return response.embeddings[0]
response = self.client.embed(texts, model=model, input_type=embedding_type)
return response.embeddings
def _embed_vertex(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float]]:
if model is None:
model = DEFAULT_VERTEX_MODEL
embedding = self.client.get_embeddings(
embeddings = self.client.get_embeddings(
[
TextEmbeddingInput(
text,
embedding_type,
)
for text in texts
]
)
return embedding[0].values
return [embedding.values for embedding in embeddings]
def _embed(
def embed(
self,
*,
text: str,
texts: list[str],
text_type: EmbedTextType,
model: str | None = None,
) -> list[float]:
model_name: str | None = None,
) -> list[list[float]]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(text, model)
return self._embed_openai(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(text, model, embedding_type)
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(text, model, embedding_type)
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(text, model, embedding_type)
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except Exception as e:
@@ -152,36 +152,6 @@ class CloudEmbedding:
detail=f"Error embedding text with {self.provider}: {str(e)}",
)
def encode(
self,
texts: list[str],
model_name: str | None,
text_type: EmbedTextType,
) -> list[list[float]]:
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(
self._embed,
text=text,
text_type=text_type,
model=model_name,
)
for text in texts
]
results = []
for future in concurrent.futures.as_completed(futures):
try:
results.append(future.result())
except Exception as e:
# Cancel all pending futures
for f in futures:
f.cancel()
# Raise the exception immediately
raise e
return results
@staticmethod
def create(
api_key: str, provider: str, model: str | None = None
@@ -252,6 +222,7 @@ def embed_text(
provider_type: str | None,
prefix: str | None,
) -> list[list[float]]:
# Third party API based embedding model
if provider_type is not None:
logger.debug(f"Embedding text with provider: {provider_type}")
if api_key is None:
@@ -268,18 +239,19 @@ def embed_text(
cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=model_name
)
embeddings = cloud_model.encode(
embeddings = cloud_model.embed(
texts=texts,
model_name=model_name,
text_type=text_type,
)
# Locally running model
elif model_name is not None:
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
hosted_model = get_embedding_model(
local_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings = hosted_model.encode(
embeddings = local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
)