Added retries and multithreading for cloud embedding (#1879)

* added retries and multithreading for cloud embedding

* refactored a bit

* cleaned up code

* got the errors to bubble up to the ui correctly

* added exceptin printing

* added requirements

* touchups

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
This commit is contained in:
hagen-danswer
2024-07-20 22:10:18 -07:00
committed by GitHub
parent 7fbbb174bb
commit eb3e7610fc
5 changed files with 203 additions and 110 deletions

View File

@@ -1,3 +1,4 @@
import concurrent.futures
import gc
import json
from typing import Any
@@ -10,6 +11,7 @@ from cohere import Client as CohereClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from vertexai.language_models import TextEmbeddingInput # type: ignore
@@ -40,65 +42,44 @@ router = APIRouter(prefix="/encoder")
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
# If we are not only indexing, dont want retry very long
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
def _initialize_client(
api_key: str, provider: EmbeddingProvider, model: str | None = None
) -> Any:
if provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=api_key)
elif provider == EmbeddingProvider.COHERE:
return CohereClient(api_key=api_key)
elif provider == EmbeddingProvider.VOYAGE:
return voyageai.Client(api_key=api_key)
elif provider == EmbeddingProvider.GOOGLE:
credentials = service_account.Credentials.from_service_account_info(
json.loads(api_key)
)
project_id = json.loads(api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
else:
raise ValueError(f"Unsupported provider: {provider}")
class CloudEmbedding:
def __init__(self, api_key: str, provider: str, model: str | None = None):
self.api_key = api_key
def __init__(
self,
api_key: str,
provider: str,
# Only for Google as is needed on client setup
self.model = model
model: str | None = None,
) -> None:
try:
self.provider = EmbeddingProvider(provider.lower())
except ValueError:
raise ValueError(f"Unsupported provider: {provider}")
self.client = self._initialize_client()
def _initialize_client(self) -> Any:
if self.provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=self.api_key)
elif self.provider == EmbeddingProvider.COHERE:
return CohereClient(api_key=self.api_key)
elif self.provider == EmbeddingProvider.VOYAGE:
return voyageai.Client(api_key=self.api_key)
elif self.provider == EmbeddingProvider.GOOGLE:
credentials = service_account.Credentials.from_service_account_info(
json.loads(self.api_key)
)
project_id = json.loads(self.api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(
self.model or DEFAULT_VERTEX_MODEL
)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def encode(
self, texts: list[str], model_name: str | None, text_type: EmbedTextType
) -> list[list[float]]:
return [
self.embed(text=text, text_type=text_type, model=model_name)
for text in texts
]
def embed(
self, *, text: str, text_type: EmbedTextType, model: str | None = None
) -> list[float]:
logger.debug(f"Embedding text with provider: {self.provider}")
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(text, model)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(text, model, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(text, model, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(text, model, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
self.client = _initialize_client(api_key, self.provider, model)
def _embed_openai(self, text: str, model: str | None) -> list[float]:
if model is None:
@@ -145,6 +126,62 @@ class CloudEmbedding:
)
return embedding[0].values
def _embed(
self,
*,
text: str,
text_type: EmbedTextType,
model: str | None = None,
) -> list[float]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(text, model)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(text, model, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(text, model, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(text, model, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error embedding text with {self.provider}: {str(e)}",
)
def encode(
self,
texts: list[str],
model_name: str | None,
text_type: EmbedTextType,
) -> list[list[float]]:
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(
self._embed,
text=text,
text_type=text_type,
model=model_name,
)
for text in texts
]
results = []
for future in concurrent.futures.as_completed(futures):
try:
results.append(future.result())
except Exception as e:
# Cancel all pending futures
for f in futures:
f.cancel()
# Raise the exception immediately
raise e
return results
@staticmethod
def create(
api_key: str, provider: str, model: str | None = None
@@ -204,6 +241,7 @@ def warm_up_cross_encoders() -> None:
@simple_log_function_time()
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
def embed_text(
texts: list[str],
text_type: EmbedTextType,
@@ -212,29 +250,50 @@ def embed_text(
normalize_embeddings: bool,
api_key: str | None,
provider_type: str | None,
prefix: str | None,
) -> list[list[float]]:
if provider_type is not None:
logger.debug(f"Embedding text with provider: {provider_type}")
if api_key is None:
raise RuntimeError("API key not provided for cloud model")
if prefix:
# This may change in the future if some providers require the user
# to manually append a prefix but this is not the case currently
raise ValueError(
"Prefix string is not valid for cloud models. "
"Cloud models take an explicit text type instead."
)
cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=model_name
)
embeddings = cloud_model.encode(texts, model_name, text_type)
embeddings = cloud_model.encode(
texts=texts,
model_name=model_name,
text_type=text_type,
)
elif model_name is not None:
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
hosted_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings = hosted_model.encode(
texts, normalize_embeddings=normalize_embeddings
prefixed_texts, normalize_embeddings=normalize_embeddings
)
else:
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)
if embeddings is None:
raise RuntimeError("Embeddings were not created")
raise RuntimeError("Failed to create Embeddings")
if not isinstance(embeddings, list):
embeddings = embeddings.tolist()
return embeddings
@@ -253,6 +312,13 @@ async def process_embed_request(
embed_request: EmbedRequest,
) -> EmbedResponse:
try:
if embed_request.text_type == EmbedTextType.QUERY:
prefix = embed_request.manual_query_prefix
elif embed_request.text_type == EmbedTextType.PASSAGE:
prefix = embed_request.manual_passage_prefix
else:
prefix = None
embeddings = embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
@@ -261,13 +327,13 @@ async def process_embed_request(
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
prefix=prefix,
)
return EmbedResponse(embeddings=embeddings)
except Exception as e:
logger.exception(f"Error during embedding process:\n{str(e)}")
raise HTTPException(
status_code=500, detail="Failed to run Bi-Encoder embedding"
)
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)
raise HTTPException(status_code=500, detail=exception_detail)
@router.post("/cross-encoder-scores")