Fix issue where large docs/batches break openai embedding

Weves 2024-08-02 00:57:50 -07:00 committed by Chris Weaver
parent f280586e68
commit 51731ad0dd
6 changed files with 123 additions and 21 deletions

View File

@@ -38,6 +38,8 @@ ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# don't send too many chunks at once, as overly large requests could cause timeouts
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512
# For score display purposes, the only way is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0
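
For a rough sense of how these limits compose (a sketch with made-up sizes, using the batch_list helper added to shared_configs/utils.py later in this diff): Danswer sends at most 512 chunks per embedding request to the model server, and the model server then re-splits each request to the provider-specific caps introduced below (2048 for OpenAI, 96 for Cohere).

from shared_configs.utils import batch_list

texts = ["chunk"] * 1200  # hypothetical number of chunks to embed
api_requests = batch_list(texts, 512)  # BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES
cohere_calls = batch_list(api_requests[0], 96)  # provider-side split for the first request
assert [len(b) for b in api_requests] == [512, 512, 176]
assert [len(b) for b in cohere_calls] == [96] * 5 + [32]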

View File

@@ -4,11 +4,13 @@ import requests
from httpx import HTTPError
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import (
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import EmbeddingModel as DBEmbeddingModel
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.batching import batch_list
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
@@ -20,6 +22,7 @@ from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
from shared_configs.utils import batch_list
logger = setup_logger()
@@ -73,7 +76,8 @@ class EmbeddingModel:
self,
texts: list[str],
text_type: EmbedTextType,
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
) -> list[Embedding]:
if not texts or not all(texts):
raise ValueError(f"Empty or missing text for embedding: {texts}")
@@ -95,6 +99,7 @@
]
if self.provider_type:
text_batches = batch_list(texts, api_embedding_batch_size)
embed_request = EmbedRequest(
model_name=self.model_name,
texts=texts,
@@ -120,7 +125,7 @@
return EmbedResponse(**response.json()).embeddings
# Batching for local embedding
text_batches = batch_list(texts, batch_size)
text_batches = batch_list(texts, local_embedding_batch_size)
embeddings: list[Embedding] = []
logger.debug(
f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"

View File

@@ -21,10 +21,3 @@ def batch_generator(
if pre_batch_yield:
pre_batch_yield(batch)
yield batch
def batch_list(
lst: list[T],
batch_size: int,
) -> list[list[T]]:
return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]

View File

@@ -34,6 +34,7 @@ from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
from shared_configs.utils import batch_list
logger = setup_logger()
@@ -46,6 +47,11 @@ _RERANK_MODELS: Optional[list["CrossEncoder"]] = None
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
# OpenAI only allows 2048 embeddings to be computed at once
_OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding call
_COHERE_MAX_INPUT_LEN = 96
def _initialize_client(
api_key: str, provider: EmbeddingProvider, model: str | None = None
@@ -88,9 +94,14 @@ class CloudEmbedding:
# OpenAI does not seem to provide a truncation option, however
# the context lengths currently used by Danswer are smaller than the max token length
# for OpenAI embeddings, so it's not a big deal
final_embeddings: list[Embedding] = []
try:
response = self.client.embeddings.create(input=texts, model=model)
return [embedding.embedding for embedding in response.data]
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = self.client.embeddings.create(input=text_batch, model=model)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
except Exception as e:
error_string = (
f"Error embedding text with OpenAI: {str(e)} \n"
@@ -107,15 +118,18 @@
if model is None:
model = DEFAULT_COHERE_MODEL
# Does not use the same tokenizer as the Danswer API server but it's approximately the same
# empirically it's only off by a very few tokens so it's not a big deal
response = self.client.embed(
texts=texts,
model=model,
input_type=embedding_type,
truncate="END",
)
return response.embeddings
final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
# Does not use the same tokenizer as the Danswer API server but it's approximately the same
# empirically it's only off by a very few tokens so it's not a big deal
response = self.client.embed(
texts=text_batch,
model=model,
input_type=embedding_type,
truncate="END",
)
final_embeddings.extend(response.embeddings)
return final_embeddings
def _embed_voyage(
self, texts: list[str], model: str | None, embedding_type: str

View File

@@ -0,0 +1,11 @@
from typing import TypeVar

T = TypeVar("T")


def batch_list(
    lst: list[T],
    batch_size: int,
) -> list[list[T]]:
    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
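
A few illustrative properties of the helper (hypothetical values, just to show the slicing behavior): the last batch may be shorter than batch_size, an empty input yields no batches, and flattening the batches reproduces the original order.

assert batch_list([], 8) == []
assert batch_list([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]
assert [x for batch in batch_list(list(range(10)), 3) for x in batch] == list(range(10))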

View File

@@ -0,0 +1,77 @@
import os

import pytest
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from shared_configs.enums import EmbedTextType

VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
# openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
# seem to be true
TOO_LONG_SAMPLE = ["a"] * 2500


def _run_embeddings(
    texts: list[str], embedding_model: EmbeddingModel, expected_dim: int
) -> None:
    for text_type in [EmbedTextType.QUERY, EmbedTextType.PASSAGE]:
        embeddings = embedding_model.encode(texts, text_type)
        assert len(embeddings) == len(texts)
        assert len(embeddings[0]) == expected_dim


@pytest.fixture
def openai_embedding_model() -> EmbeddingModel:
    return EmbeddingModel(
        server_host="localhost",
        server_port=9000,
        model_name="text-embedding-3-small",
        normalize=True,
        query_prefix=None,
        passage_prefix=None,
        api_key=os.getenv("OPENAI_API_KEY"),
        provider_type="openai",
    )


def test_openai_embedding(openai_embedding_model: EmbeddingModel) -> None:
    _run_embeddings(VALID_SAMPLE, openai_embedding_model, 1536)
    _run_embeddings(TOO_LONG_SAMPLE, openai_embedding_model, 1536)


@pytest.fixture
def cohere_embedding_model() -> EmbeddingModel:
    return EmbeddingModel(
        server_host="localhost",
        server_port=9000,
        model_name="embed-english-light-v3.0",
        normalize=True,
        query_prefix=None,
        passage_prefix=None,
        api_key=os.getenv("COHERE_API_KEY"),
        provider_type="cohere",
    )


def test_cohere_embedding(cohere_embedding_model: EmbeddingModel) -> None:
    _run_embeddings(VALID_SAMPLE, cohere_embedding_model, 384)
    _run_embeddings(TOO_LONG_SAMPLE, cohere_embedding_model, 384)


@pytest.fixture
def local_nomic_embedding_model() -> EmbeddingModel:
    return EmbeddingModel(
        server_host="localhost",
        server_port=9000,
        model_name="nomic-ai/nomic-embed-text-v1",
        normalize=True,
        query_prefix="search_query: ",
        passage_prefix="search_document: ",
        api_key=None,
        provider_type=None,
    )


def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None:
    _run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768)
    _run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768)
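
Note that these are integration tests rather than unit tests: the fixtures above expect a model server listening on localhost:9000 and read OPENAI_API_KEY / COHERE_API_KEY from the environment, so the provider-backed tests will fail without those in place.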