diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index 0619757b8..4a3466f0b 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -38,6 +38,8 @@ ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
 ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
 # Purely an optimization, memory limitation consideration
 BATCH_SIZE_ENCODE_CHUNKS = 8
+# Don't send too many chunks in one request, as very large payloads can cause timeouts
+BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512
 # For score display purposes, only way is to know the expected ranges
 CROSS_ENCODER_RANGE_MAX = 1
 CROSS_ENCODER_RANGE_MIN = 0
diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py
index da17ca4f2..bf208fff8 100644
--- a/backend/danswer/natural_language_processing/search_nlp_models.py
+++ b/backend/danswer/natural_language_processing/search_nlp_models.py
@@ -4,11 +4,13 @@ import requests
 from httpx import HTTPError
 
 from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
+from danswer.configs.model_configs import (
+    BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
+)
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.models import EmbeddingModel as DBEmbeddingModel
 from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.natural_language_processing.utils import tokenizer_trim_content
-from danswer.utils.batching import batch_list
 from danswer.utils.logger import setup_logger
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
@@ -20,6 +22,7 @@ from shared_configs.model_server_models import IntentRequest
 from shared_configs.model_server_models import IntentResponse
 from shared_configs.model_server_models import RerankRequest
 from shared_configs.model_server_models import RerankResponse
+from shared_configs.utils import batch_list
 
 
 logger = setup_logger()
@@ -73,7 +76,8 @@
         self,
         texts: list[str],
         text_type: EmbedTextType,
-        batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
+        local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
+        api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
     ) -> list[Embedding]:
         if not texts or not all(texts):
             raise ValueError(f"Empty or missing text for embedding: {texts}")
@@ -95,6 +99,7 @@ class EmbeddingModel:
         ]
 
         if self.provider_type:
+            text_batches = batch_list(texts, api_embedding_batch_size)
            embed_request = EmbedRequest(
                model_name=self.model_name,
                texts=texts,
@@ -120,7 +125,7 @@
            return EmbedResponse(**response.json()).embeddings
 
         # Batching for local embedding
-        text_batches = batch_list(texts, batch_size)
+        text_batches = batch_list(texts, local_embedding_batch_size)
         embeddings: list[Embedding] = []
         logger.debug(
             f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
diff --git a/backend/danswer/utils/batching.py b/backend/danswer/utils/batching.py
index 2ea436e11..0200f7225 100644
--- a/backend/danswer/utils/batching.py
+++ b/backend/danswer/utils/batching.py
@@ -21,10 +21,3 @@ def batch_generator(
         if pre_batch_yield:
             pre_batch_yield(batch)
         yield batch
-
-
-def batch_list(
-    lst: list[T],
-    batch_size: int,
-) -> list[list[T]]:
-    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
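Note: batch_list is removed from danswer/utils/batching.py here and re-added verbatim in backend/shared_configs/utils.py further down, so both the API server and the model server can import it without depending on danswer.utils. A minimal standalone sketch of its behavior (the sample sizes below are illustrative, not part of the patch):

    from typing import TypeVar

    T = TypeVar("T")

    # Same helper that this diff moves into shared_configs/utils.py.
    def batch_list(lst: list[T], batch_size: int) -> list[list[T]]:
        return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]

    # 2500 texts at the new API batch size of 512 -> five batches, the last partial.
    batches = batch_list(["text"] * 2500, 512)
    assert [len(b) for b in batches] == [512, 512, 512, 512, 452]
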
diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py
index 350d4c222..f12aeb89b 100644
--- a/backend/model_server/encoders.py
+++ b/backend/model_server/encoders.py
@@ -34,6 +34,7 @@ from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
 from shared_configs.model_server_models import RerankRequest
 from shared_configs.model_server_models import RerankResponse
+from shared_configs.utils import batch_list
 
 
 logger = setup_logger()
@@ -46,6 +47,11 @@ _RERANK_MODELS: Optional[list["CrossEncoder"]] = None
 _RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
 _RETRY_TRIES = 10 if INDEXING_ONLY else 2
+
+# OpenAI only allows 2048 embeddings to be computed at once
+_OPENAI_MAX_INPUT_LEN = 2048
+# Cohere allows up to 96 embeddings in a single embedding call
+_COHERE_MAX_INPUT_LEN = 96
 
 
 def _initialize_client(
     api_key: str, provider: EmbeddingProvider, model: str | None = None
@@ -88,9 +94,14 @@ class CloudEmbedding:
         # OpenAI does not seem to provide truncation option, however
         # the context lengths used by Danswer currently are smaller than the max token length
         # for OpenAI embeddings so it's not a big deal
+        final_embeddings: list[Embedding] = []
         try:
-            response = self.client.embeddings.create(input=texts, model=model)
-            return [embedding.embedding for embedding in response.data]
+            for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
+                response = self.client.embeddings.create(input=text_batch, model=model)
+                final_embeddings.extend(
+                    [embedding.embedding for embedding in response.data]
+                )
+            return final_embeddings
         except Exception as e:
             error_string = (
                 f"Error embedding text with OpenAI: {str(e)} \n"
@@ -107,15 +118,18 @@ class CloudEmbedding:
         if model is None:
             model = DEFAULT_COHERE_MODEL
 
-        # Does not use the same tokenizer as the Danswer API server but it's approximately the same
-        # empirically it's only off by a very few tokens so it's not a big deal
-        response = self.client.embed(
-            texts=texts,
-            model=model,
-            input_type=embedding_type,
-            truncate="END",
-        )
-        return response.embeddings
+        final_embeddings: list[Embedding] = []
+        for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
+            # Does not use the same tokenizer as the Danswer API server but it's approximately the same
+            # empirically it's only off by a very few tokens so it's not a big deal
+            response = self.client.embed(
+                texts=text_batch,
+                model=model,
+                input_type=embedding_type,
+                truncate="END",
+            )
+            final_embeddings.extend(response.embeddings)
+        return final_embeddings
 
     def _embed_voyage(
         self, texts: list[str], model: str | None, embedding_type: str
diff --git a/backend/shared_configs/utils.py b/backend/shared_configs/utils.py
new file mode 100644
index 000000000..c40795eb4
--- /dev/null
+++ b/backend/shared_configs/utils.py
@@ -0,0 +1,11 @@
+from typing import TypeVar
+
+
+T = TypeVar("T")
+
+
+def batch_list(
+    lst: list[T],
+    batch_size: int,
+) -> list[list[T]]:
+    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
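Note: the integration test file added below sends TOO_LONG_SAMPLE, 2500 texts, which exceeds both of the provider limits introduced in encoders.py above. A quick sketch of the upstream call counts those limits imply, assuming the full payload reaches the model server in a single request:

    import math

    # Values mirror _OPENAI_MAX_INPUT_LEN / _COHERE_MAX_INPUT_LEN above;
    # 2500 matches len(TOO_LONG_SAMPLE) in the test file below.
    n_texts = 2500
    print(math.ceil(n_texts / 2048))  # 2 requests to OpenAI
    print(math.ceil(n_texts / 96))    # 27 requests to Cohere
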
😃"] +# openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't +# seem to be true +TOO_LONG_SAMPLE = ["a"] * 2500 + + +def _run_embeddings( + texts: list[str], embedding_model: EmbeddingModel, expected_dim: int +) -> None: + for text_type in [EmbedTextType.QUERY, EmbedTextType.PASSAGE]: + embeddings = embedding_model.encode(texts, text_type) + assert len(embeddings) == len(texts) + assert len(embeddings[0]) == expected_dim + + +@pytest.fixture +def openai_embedding_model() -> EmbeddingModel: + return EmbeddingModel( + server_host="localhost", + server_port=9000, + model_name="text-embedding-3-small", + normalize=True, + query_prefix=None, + passage_prefix=None, + api_key=os.getenv("OPENAI_API_KEY"), + provider_type="openai", + ) + + +def test_openai_embedding(openai_embedding_model: EmbeddingModel) -> None: + _run_embeddings(VALID_SAMPLE, openai_embedding_model, 1536) + _run_embeddings(TOO_LONG_SAMPLE, openai_embedding_model, 1536) + + +@pytest.fixture +def cohere_embedding_model() -> EmbeddingModel: + return EmbeddingModel( + server_host="localhost", + server_port=9000, + model_name="embed-english-light-v3.0", + normalize=True, + query_prefix=None, + passage_prefix=None, + api_key=os.getenv("COHERE_API_KEY"), + provider_type="cohere", + ) + + +def test_cohere_embedding(cohere_embedding_model: EmbeddingModel) -> None: + _run_embeddings(VALID_SAMPLE, cohere_embedding_model, 384) + _run_embeddings(TOO_LONG_SAMPLE, cohere_embedding_model, 384) + + +@pytest.fixture +def local_nomic_embedding_model() -> EmbeddingModel: + return EmbeddingModel( + server_host="localhost", + server_port=9000, + model_name="nomic-ai/nomic-embed-text-v1", + normalize=True, + query_prefix="search_query: ", + passage_prefix="search_document: ", + api_key=None, + provider_type=None, + ) + + +def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None: + _run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768) + _run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768)