mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-12 14:12:53 +02:00)
Fix issue where large docs/batches break openai embedding
@@ -34,6 +34,7 @@ from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
 from shared_configs.model_server_models import RerankRequest
 from shared_configs.model_server_models import RerankResponse
+from shared_configs.utils import batch_list
 
 
 logger = setup_logger()
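
Note: the fix leans on batch_list from shared_configs.utils, which is outside this diff. Below is a minimal sketch of a chunking helper with the behavior the diff assumes; the generic signature is an assumption, not the repo's actual code.

from typing import TypeVar

T = TypeVar("T")


def batch_list(lst: list[T], batch_size: int) -> list[list[T]]:
    # Assumed behavior: split lst into consecutive chunks of at most
    # batch_size items; the last chunk may be shorter. Order is preserved.
    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
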
@@ -46,6 +47,11 @@ _RERANK_MODELS: Optional[list["CrossEncoder"]] = None
 _RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
 _RETRY_TRIES = 10 if INDEXING_ONLY else 2
 
+# OpenAI only allows 2048 embeddings to be computed at once
+_OPENAI_MAX_INPUT_LEN = 2048
+# Cohere allows up to 96 embeddings in a single embedding call
+_COHERE_MAX_INPUT_LEN = 96
+
 
 def _initialize_client(
     api_key: str, provider: EmbeddingProvider, model: str | None = None
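
Why these limits matter: OpenAI rejects requests with more than 2048 embedding inputs and Cohere caps a single embed call at 96 texts, so a large indexing batch previously failed outright. A quick illustration using the batch_list sketch above (the 5000-chunk figure is made up for the example):

texts = ["chunk"] * 5000

# 5000 inputs -> 3 OpenAI requests (2048 + 2048 + 904) ...
openai_batches = batch_list(texts, 2048)
assert [len(b) for b in openai_batches] == [2048, 2048, 904]

# ... and 53 Cohere requests (52 full batches of 96, plus one of 8)
cohere_batches = batch_list(texts, 96)
assert len(cohere_batches) == 53 and len(cohere_batches[-1]) == 8
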
@@ -88,9 +94,14 @@ class CloudEmbedding:
         # OpenAI does not seem to provide a truncation option; however,
         # the context lengths used by Danswer are currently smaller than the max token length
         # for OpenAI embeddings, so it's not a big deal
+        final_embeddings: list[Embedding] = []
         try:
-            response = self.client.embeddings.create(input=texts, model=model)
-            return [embedding.embedding for embedding in response.data]
+            for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
+                response = self.client.embeddings.create(input=text_batch, model=model)
+                final_embeddings.extend(
+                    [embedding.embedding for embedding in response.data]
+                )
+            return final_embeddings
         except Exception as e:
             error_string = (
                 f"Error embedding text with OpenAI: {str(e)} \n"
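
For reference, the same batching pattern as a standalone function against the current openai Python client. This is a sketch: the client construction and model name are placeholders, not what Danswer configures.

from openai import OpenAI

_OPENAI_MAX_INPUT_LEN = 2048  # mirrors the constant added above

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment


def embed_texts(texts: list[str], model: str = "text-embedding-3-small") -> list[list[float]]:
    final_embeddings: list[list[float]] = []
    for i in range(0, len(texts), _OPENAI_MAX_INPUT_LEN):
        batch = texts[i : i + _OPENAI_MAX_INPUT_LEN]
        response = client.embeddings.create(input=batch, model=model)
        # response.data comes back in input order, so extending keeps
        # the embeddings aligned with the original texts
        final_embeddings.extend(d.embedding for d in response.data)
    return final_embeddings
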
@@ -107,15 +118,18 @@ class CloudEmbedding:
         if model is None:
             model = DEFAULT_COHERE_MODEL
 
-        # Does not use the same tokenizer as the Danswer API server, but it's approximately
-        # the same; empirically it's only off by a few tokens, so it's not a big deal
-        response = self.client.embed(
-            texts=texts,
-            model=model,
-            input_type=embedding_type,
-            truncate="END",
-        )
-        return response.embeddings
+        final_embeddings: list[Embedding] = []
+        for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
+            # Does not use the same tokenizer as the Danswer API server, but it's approximately
+            # the same; empirically it's only off by a few tokens, so it's not a big deal
+            response = self.client.embed(
+                texts=text_batch,
+                model=model,
+                input_type=embedding_type,
+                truncate="END",
+            )
+            final_embeddings.extend(response.embeddings)
+        return final_embeddings
 
     def _embed_voyage(
         self, texts: list[str], model: str | None, embedding_type: str
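
And the Cohere side in the same standalone form. The embed call shape matches the cohere SDK used here; the model name and input_type value are illustrative, since Danswer passes its own from config.

import cohere

_COHERE_MAX_INPUT_LEN = 96  # mirrors the constant added above

co = cohere.Client("YOUR_API_KEY")  # placeholder key


def embed_texts_cohere(texts: list[str], model: str = "embed-english-v3.0") -> list[list[float]]:
    final_embeddings: list[list[float]] = []
    for i in range(0, len(texts), _COHERE_MAX_INPUT_LEN):
        response = co.embed(
            texts=texts[i : i + _COHERE_MAX_INPUT_LEN],
            model=model,
            input_type="search_document",  # "search_query" at query time
            truncate="END",  # cut off tokens past the model limit from the end
        )
        final_embeddings.extend(response.embeddings)
    return final_embeddings
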