Fix issue where large docs/batches break openai embedding

Weves
2024-08-02 00:57:50 -07:00
committed by Chris Weaver
parent f280586e68
commit 51731ad0dd
6 changed files with 123 additions and 21 deletions


@@ -34,6 +34,7 @@ from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
 from shared_configs.model_server_models import RerankRequest
 from shared_configs.model_server_models import RerankResponse
+from shared_configs.utils import batch_list
 
 logger = setup_logger()
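
The newly imported batch_list helper lives in shared_configs/utils and its body isn't shown in this diff. A minimal sketch of what such a helper presumably looks like, assuming it simply chunks a list into fixed-size sublists:

from typing import TypeVar

T = TypeVar("T")


def batch_list(lst: list[T], batch_size: int) -> list[list[T]]:
    # Split lst into consecutive chunks of at most batch_size items;
    # the final chunk holds whatever remains.
    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]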
@@ -46,6 +47,11 @@ _RERANK_MODELS: Optional[list["CrossEncoder"]] = None
 _RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
 _RETRY_TRIES = 10 if INDEXING_ONLY else 2
 
+# OpenAI only allows 2048 embeddings to be computed at once
+_OPENAI_MAX_INPUT_LEN = 2048
+# Cohere allows up to 96 embeddings in a single embedding call
+_COHERE_MAX_INPUT_LEN = 96
+
 
 def _initialize_client(
     api_key: str, provider: EmbeddingProvider, model: str | None = None
@@ -88,9 +94,14 @@ class CloudEmbedding:
         # OpenAI does not seem to provide a truncation option; however,
         # the context lengths used by Danswer are currently smaller than the max token length
         # for OpenAI embeddings, so it's not a big deal
+        final_embeddings: list[Embedding] = []
         try:
-            response = self.client.embeddings.create(input=texts, model=model)
-            return [embedding.embedding for embedding in response.data]
+            for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
+                response = self.client.embeddings.create(input=text_batch, model=model)
+                final_embeddings.extend(
+                    [embedding.embedding for embedding in response.data]
+                )
+            return final_embeddings
         except Exception as e:
             error_string = (
                 f"Error embedding text with OpenAI: {str(e)} \n"
@@ -107,15 +118,18 @@ class CloudEmbedding:
         if model is None:
             model = DEFAULT_COHERE_MODEL
 
-        # Does not use the same tokenizer as the Danswer API server, but it's approximately the same;
-        # empirically it's only off by a few tokens, so it's not a big deal
-        response = self.client.embed(
-            texts=texts,
-            model=model,
-            input_type=embedding_type,
-            truncate="END",
-        )
-        return response.embeddings
+        final_embeddings: list[Embedding] = []
+        for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
+            # Does not use the same tokenizer as the Danswer API server, but it's approximately the same;
+            # empirically it's only off by a few tokens, so it's not a big deal
+            response = self.client.embed(
+                texts=text_batch,
+                model=model,
+                input_type=embedding_type,
+                truncate="END",
+            )
+            final_embeddings.extend(response.embeddings)
+        return final_embeddings
 
     def _embed_voyage(
         self, texts: list[str], model: str | None, embedding_type: str
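
Both the OpenAI and Cohere paths now follow the same chunk-call-extend shape, differing only in the per-request cap. A hedged sketch of how that shared pattern could be factored out (hypothetical, not part of this commit; Embedding is assumed here to be a list[float] alias):

from collections.abc import Callable

Embedding = list[float]  # assumption; the real alias comes from Danswer's shared models


def embed_in_batches(
    texts: list[str],
    max_batch_size: int,
    embed_fn: Callable[[list[str]], list[Embedding]],
) -> list[Embedding]:
    # Generic chunk -> call -> extend loop mirroring both provider paths above.
    embeddings: list[Embedding] = []
    for batch in batch_list(texts, max_batch_size):
        embeddings.extend(embed_fn(batch))
    return embeddings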