Fix issue where large docs/batches break openai embedding

commit 51731ad0dd (parent f280586e68)
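In short: embedding a large document or a large batch of chunks could push a single cloud embedding request past the provider's per-request input limit. The fix adds a shared batch_list helper and splits embedding inputs into provider-sized batches before calling the cloud APIs. A rough sketch of the pattern, assuming an OpenAI-style client object (embed_with_openai itself is illustrative and not code from this commit):

from shared_configs.utils import batch_list

# OpenAI only allows 2048 embeddings to be computed at once
_OPENAI_MAX_INPUT_LEN = 2048


def embed_with_openai(client, texts: list[str], model: str) -> list[list[float]]:
    # Accumulate embeddings across provider-sized batches instead of one giant request
    final_embeddings: list[list[float]] = []
    for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
        response = client.embeddings.create(input=text_batch, model=model)
        final_embeddings.extend(embedding.embedding for embedding in response.data)
    return final_embeddings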
@@ -38,6 +38,8 @@ ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
 ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
 # Purely an optimization, memory limitation consideration
 BATCH_SIZE_ENCODE_CHUNKS = 8
+# don't send over too many chunks at once, as sending too many could cause timeouts
+BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512
 # For score display purposes, only way is to know the expected ranges
 CROSS_ENCODER_RANGE_MAX = 1
 CROSS_ENCODER_RANGE_MIN = 0
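These constants come from danswer.configs.model_configs (see the imports in the next hunk). To make the scale difference concrete, a purely illustrative calculation of how many batches each setting produces for 10,000 chunks (the chunk count is made up):

from math import ceil

BATCH_SIZE_ENCODE_CHUNKS = 8
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512

num_chunks = 10_000
print(ceil(num_chunks / BATCH_SIZE_ENCODE_CHUNKS))                             # 1250 batches for the local model
print(ceil(num_chunks / BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES))  # 20 requests to the embedding API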
@@ -4,11 +4,13 @@ import requests
 from httpx import HTTPError
 
 from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
+from danswer.configs.model_configs import (
+    BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
+)
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.models import EmbeddingModel as DBEmbeddingModel
 from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.natural_language_processing.utils import tokenizer_trim_content
-from danswer.utils.batching import batch_list
 from danswer.utils.logger import setup_logger
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
@@ -20,6 +22,7 @@ from shared_configs.model_server_models import IntentRequest
 from shared_configs.model_server_models import IntentResponse
 from shared_configs.model_server_models import RerankRequest
 from shared_configs.model_server_models import RerankResponse
+from shared_configs.utils import batch_list
 
 logger = setup_logger()
 
@@ -73,7 +76,8 @@ class EmbeddingModel:
         self,
         texts: list[str],
         text_type: EmbedTextType,
-        batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
+        local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
+        api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
     ) -> list[Embedding]:
         if not texts or not all(texts):
             raise ValueError(f"Empty or missing text for embedding: {texts}")
@@ -95,6 +99,7 @@ class EmbeddingModel:
         ]
 
         if self.provider_type:
+            text_batches = batch_list(texts, api_embedding_batch_size)
            embed_request = EmbedRequest(
                model_name=self.model_name,
                texts=texts,
@@ -120,7 +125,7 @@ class EmbeddingModel:
            return EmbedResponse(**response.json()).embeddings
 
        # Batching for local embedding
-       text_batches = batch_list(texts, batch_size)
+       text_batches = batch_list(texts, local_embedding_batch_size)
        embeddings: list[Embedding] = []
        logger.debug(
            f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
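Existing callers that rely on the defaults keep working unchanged; the old batch_size parameter is replaced by local_embedding_batch_size and api_embedding_batch_size, and the API path picks the larger default. A minimal usage sketch (constructor arguments mirror the integration-test fixtures added below; the snippet itself is not part of the commit):

from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from shared_configs.enums import EmbedTextType

model = EmbeddingModel(
    server_host="localhost",
    server_port=9000,
    model_name="text-embedding-3-small",
    normalize=True,
    query_prefix=None,
    passage_prefix=None,
    api_key="...",            # a real OpenAI key is needed for the cloud path
    provider_type="openai",
)
embeddings = model.encode(
    ["chunk one", "chunk two"],
    text_type=EmbedTextType.PASSAGE,
    api_embedding_batch_size=512,  # optional override; default is BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES
)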
@@ -21,10 +21,3 @@ def batch_generator(
         if pre_batch_yield:
             pre_batch_yield(batch)
         yield batch
-
-
-def batch_list(
-    lst: list[T],
-    batch_size: int,
-) -> list[list[T]]:
-    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
@@ -34,6 +34,7 @@ from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
 from shared_configs.model_server_models import RerankRequest
 from shared_configs.model_server_models import RerankResponse
+from shared_configs.utils import batch_list
 
 
 logger = setup_logger()
@@ -46,6 +47,11 @@ _RERANK_MODELS: Optional[list["CrossEncoder"]] = None
 _RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
 _RETRY_TRIES = 10 if INDEXING_ONLY else 2
 
+# OpenAI only allows 2048 embeddings to be computed at once
+_OPENAI_MAX_INPUT_LEN = 2048
+# Cohere allows up to 96 embeddings in a single embedding calling
+_COHERE_MAX_INPUT_LEN = 96
+
 
 def _initialize_client(
     api_key: str, provider: EmbeddingProvider, model: str | None = None
@@ -88,9 +94,14 @@ class CloudEmbedding:
        # OpenAI does not seem to provide truncation option, however
        # the context lengths used by Danswer currently are smaller than the max token length
        # for OpenAI embeddings so it's not a big deal
+       final_embeddings: list[Embedding] = []
        try:
-           response = self.client.embeddings.create(input=texts, model=model)
-           return [embedding.embedding for embedding in response.data]
+           for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
+               response = self.client.embeddings.create(input=text_batch, model=model)
+               final_embeddings.extend(
+                   [embedding.embedding for embedding in response.data]
+               )
+           return final_embeddings
        except Exception as e:
            error_string = (
                f"Error embedding text with OpenAI: {str(e)} \n"
@@ -107,15 +118,18 @@ class CloudEmbedding:
        if model is None:
            model = DEFAULT_COHERE_MODEL
 
-       # Does not use the same tokenizer as the Danswer API server but it's approximately the same
-       # empirically it's only off by a very few tokens so it's not a big deal
-       response = self.client.embed(
-           texts=texts,
-           model=model,
-           input_type=embedding_type,
-           truncate="END",
-       )
-       return response.embeddings
+       final_embeddings: list[Embedding] = []
+       for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
+           # Does not use the same tokenizer as the Danswer API server but it's approximately the same
+           # empirically it's only off by a very few tokens so it's not a big deal
+           response = self.client.embed(
+               texts=text_batch,
+               model=model,
+               input_type=embedding_type,
+               truncate="END",
+           )
+           final_embeddings.extend(response.embeddings)
+       return final_embeddings
 
     def _embed_voyage(
         self, texts: list[str], model: str | None, embedding_type: str
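With both layers in place, a client-side batch of 512 texts stays well under OpenAI's 2048-input cap, while Cohere requests are further split into groups of at most 96 texts. A tiny illustrative check (not part of the commit):

from math import ceil

assert ceil(512 / 2048) == 1  # OpenAI: one provider call per 512-text batch
assert ceil(512 / 96) == 6    # Cohere: six provider calls per 512-text batch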
backend/shared_configs/utils.py (new file, 11 lines)
@@ -0,0 +1,11 @@
+from typing import TypeVar
+
+
+T = TypeVar("T")
+
+
+def batch_list(
+    lst: list[T],
+    batch_size: int,
+) -> list[list[T]]:
+    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
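The helper moved here from danswer.utils.batching, presumably so that both the Danswer API server and the model server can import it. A quick usage sketch (not part of the commit):

from shared_configs.utils import batch_list

# Consecutive chunks of at most batch_size items; the last batch may be shorter.
assert batch_list([1, 2, 3, 4, 5], batch_size=2) == [[1, 2], [3, 4], [5]]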
backend/tests/integration/embedding/test_embeddings.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+import os
+
+import pytest
+
+from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
+from shared_configs.enums import EmbedTextType
+
+VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
+# openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
+# seem to be true
+TOO_LONG_SAMPLE = ["a"] * 2500
+
+
+def _run_embeddings(
+    texts: list[str], embedding_model: EmbeddingModel, expected_dim: int
+) -> None:
+    for text_type in [EmbedTextType.QUERY, EmbedTextType.PASSAGE]:
+        embeddings = embedding_model.encode(texts, text_type)
+        assert len(embeddings) == len(texts)
+        assert len(embeddings[0]) == expected_dim
+
+
+@pytest.fixture
+def openai_embedding_model() -> EmbeddingModel:
+    return EmbeddingModel(
+        server_host="localhost",
+        server_port=9000,
+        model_name="text-embedding-3-small",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+        api_key=os.getenv("OPENAI_API_KEY"),
+        provider_type="openai",
+    )
+
+
+def test_openai_embedding(openai_embedding_model: EmbeddingModel) -> None:
+    _run_embeddings(VALID_SAMPLE, openai_embedding_model, 1536)
+    _run_embeddings(TOO_LONG_SAMPLE, openai_embedding_model, 1536)
+
+
+@pytest.fixture
+def cohere_embedding_model() -> EmbeddingModel:
+    return EmbeddingModel(
+        server_host="localhost",
+        server_port=9000,
+        model_name="embed-english-light-v3.0",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+        api_key=os.getenv("COHERE_API_KEY"),
+        provider_type="cohere",
+    )
+
+
+def test_cohere_embedding(cohere_embedding_model: EmbeddingModel) -> None:
+    _run_embeddings(VALID_SAMPLE, cohere_embedding_model, 384)
+    _run_embeddings(TOO_LONG_SAMPLE, cohere_embedding_model, 384)
+
+
+@pytest.fixture
+def local_nomic_embedding_model() -> EmbeddingModel:
+    return EmbeddingModel(
+        server_host="localhost",
+        server_port=9000,
+        model_name="nomic-ai/nomic-embed-text-v1",
+        normalize=True,
+        query_prefix="search_query: ",
+        passage_prefix="search_document: ",
+        api_key=None,
+        provider_type=None,
+    )
+
+
+def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None:
+    _run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768)
+    _run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768)
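As the fixtures suggest, these integration tests expect a model server reachable at localhost:9000, plus OPENAI_API_KEY and COHERE_API_KEY set in the environment for the cloud cases; the local nomic case needs no API key.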