diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py
index f252b29d172..003953cb29a 100644
--- a/backend/model_server/encoders.py
+++ b/backend/model_server/encoders.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any
+from typing import cast
 from typing import Optional
 
 import httpx
@@ -25,6 +25,7 @@ from model_server.constants import DEFAULT_VOYAGE_MODEL
 from model_server.constants import EmbeddingModelTextType
 from model_server.constants import EmbeddingProvider
 from model_server.utils import simple_log_function_time
+from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
 from shared_configs.enums import EmbedTextType
@@ -54,32 +55,6 @@ _OPENAI_MAX_INPUT_LEN = 2048
 _COHERE_MAX_INPUT_LEN = 96
 
 
-def _initialize_client(
-    api_key: str,
-    provider: EmbeddingProvider,
-    model: str | None = None,
-    api_url: str | None = None,
-    api_version: str | None = None,
-) -> Any:
-    if provider == EmbeddingProvider.OPENAI:
-        return openai.OpenAI(api_key=api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
-    elif provider == EmbeddingProvider.COHERE:
-        return CohereClient(api_key=api_key)
-    elif provider == EmbeddingProvider.VOYAGE:
-        return voyageai.Client(api_key=api_key)
-    elif provider == EmbeddingProvider.GOOGLE:
-        credentials = service_account.Credentials.from_service_account_info(
-            json.loads(api_key)
-        )
-        project_id = json.loads(api_key)["project_id"]
-        vertexai.init(project=project_id, credentials=credentials)
-        return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
-    elif provider == EmbeddingProvider.AZURE:
-        return {"api_key": api_key, "api_url": api_url, "api_version": api_version}
-    else:
-        raise ValueError(f"Unsupported provider: {provider}")
-
-
 class CloudEmbedding:
     def __init__(
         self,
@@ -87,25 +62,22 @@ class CloudEmbedding:
         provider: EmbeddingProvider,
         api_url: str | None = None,
         api_version: str | None = None,
-        # Only for Google as is needed on client setup
-        model: str | None = None,
     ) -> None:
         self.provider = provider
-        self.client = _initialize_client(
-            api_key, self.provider, model, api_url, api_version
-        )
+        self.api_key = api_key
+        self.api_url = api_url
+        self.api_version = api_version
 
     def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
         if not model:
             model = DEFAULT_OPENAI_MODEL
 
-        # OpenAI does not seem to provide truncation option, however
-        # the context lengths used by Danswer currently are smaller than the max token length
-        # for OpenAI embeddings so it's not a big deal
+        client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
+
         final_embeddings: list[Embedding] = []
         try:
             for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
-                response = self.client.embeddings.create(input=text_batch, model=model)
+                response = client.embeddings.create(input=text_batch, model=model)
                 final_embeddings.extend(
                     [embedding.embedding for embedding in response.data]
                 )
@@ -126,17 +98,19 @@
         if not model:
             model = DEFAULT_COHERE_MODEL
 
+        client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT)
+
         final_embeddings: list[Embedding] = []
         for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
             # Does not use the same tokenizer as the Danswer API server but it's approximately the same
             # empirically it's only off by a very few tokens so it's not a big deal
-            response = self.client.embed(
+            response = client.embed(
                 texts=text_batch,
                 model=model,
                 input_type=embedding_type,
                 truncate="END",
             )
-            final_embeddings.extend(response.embeddings)
+            final_embeddings.extend(cast(list[Embedding], response.embeddings))
         return final_embeddings
 
     def _embed_voyage(
@@ -145,13 +119,15 @@ class CloudEmbedding:
         if not model:
             model = DEFAULT_VOYAGE_MODEL
 
-        # Similar to Cohere, the API server will do approximate size chunking
-        # it's acceptable to miss by a few tokens
-        response = self.client.embed(
+        client = voyageai.Client(
+            api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
+        )
+
+        response = client.embed(
             texts,
             model=model,
             input_type=embedding_type,
-            truncation=True,  # Also this is default
+            truncation=True,
         )
         return response.embeddings
 
@@ -159,9 +135,10 @@
         response = embedding(
             model=model,
             input=texts,
-            api_key=self.client["api_key"],
-            api_base=self.client["api_url"],
-            api_version=self.client["api_version"],
+            timeout=API_BASED_EMBEDDING_TIMEOUT,
+            api_key=self.api_key,
+            api_base=self.api_url,
+            api_version=self.api_version,
         )
 
         embeddings = [embedding["embedding"] for embedding in response.data]
@@ -173,7 +150,14 @@
         if not model:
             model = DEFAULT_VERTEX_MODEL
 
-        embeddings = self.client.get_embeddings(
+        credentials = service_account.Credentials.from_service_account_info(
+            json.loads(self.api_key)
+        )
+        project_id = json.loads(self.api_key)["project_id"]
+        vertexai.init(project=project_id, credentials=credentials)
+        client = TextEmbeddingModel.from_pretrained(model)
+
+        embeddings = client.get_embeddings(
             [
                 TextEmbeddingInput(
                     text,
@@ -185,6 +169,33 @@
         )
         return [embedding.values for embedding in embeddings]
 
+    def _embed_litellm_proxy(
+        self, texts: list[str], model_name: str | None
+    ) -> list[Embedding]:
+        if not model_name:
+            raise ValueError("Model name is required for LiteLLM proxy embedding.")
+
+        if not self.api_url:
+            raise ValueError("API URL is required for LiteLLM proxy embedding.")
+
+        headers = (
+            {} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
+        )
+
+        with httpx.Client() as client:
+            response = client.post(
+                self.api_url,
+                json={
+                    "model": model_name,
+                    "input": texts,
+                },
+                headers=headers,
+                timeout=API_BASED_EMBEDDING_TIMEOUT,
+            )
+            response.raise_for_status()
+            result = response.json()
+            return [embedding["embedding"] for embedding in result["data"]]
+
     @retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
     def embed(
         self,
@@ -199,6 +210,9 @@
             return self._embed_openai(texts, model_name)
         elif self.provider == EmbeddingProvider.AZURE:
             return self._embed_azure(texts, f"azure/{deployment_name}")
+        elif self.provider == EmbeddingProvider.LITELLM:
+            return self._embed_litellm_proxy(texts, model_name)
+
         embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
         if self.provider == EmbeddingProvider.COHERE:
             return self._embed_cohere(texts, model_name, embedding_type)
@@ -218,12 +232,11 @@
     def create(
         api_key: str,
         provider: EmbeddingProvider,
-        model: str | None = None,
         api_url: str | None = None,
         api_version: str | None = None,
     ) -> "CloudEmbedding":
         logger.debug(f"Creating Embedding instance for provider: {provider}")
-        return CloudEmbedding(api_key, provider, model, api_url, api_version)
+        return CloudEmbedding(api_key, provider, api_url, api_version)
 
 
 def get_embedding_model(
@@ -266,25 +279,6 @@ def get_local_reranking_model(
     return _RERANK_MODEL
 
 
-def embed_with_litellm_proxy(
-    texts: list[str], api_url: str, model_name: str, api_key: str | None
-) -> list[Embedding]:
-    headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
-
-    with httpx.Client() as client:
-        response = client.post(
-            api_url,
-            json={
-                "model": model_name,
-                "input": texts,
-            },
-            headers=headers,
-        )
-        response.raise_for_status()
-        result = response.json()
-        return [embedding["embedding"] for embedding in result["data"]]
-
-
 @simple_log_function_time()
 def embed_text(
     texts: list[str],
@@ -309,23 +303,7 @@
         logger.error("No texts provided for embedding")
         raise ValueError("No texts provided for embedding.")
 
-    if provider_type == EmbeddingProvider.LITELLM:
-        logger.debug(f"Using LiteLLM proxy for embedding with URL: {api_url}")
-        if not api_url:
-            logger.error("API URL not provided for LiteLLM proxy")
-            raise ValueError("API URL is required for LiteLLM proxy embedding.")
-        try:
-            return embed_with_litellm_proxy(
-                texts=texts,
-                api_url=api_url,
-                model_name=model_name or "",
-                api_key=api_key,
-            )
-        except Exception as e:
-            logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}")
-            raise
-
-    elif provider_type is not None:
+    if provider_type is not None:
         logger.debug(f"Using cloud provider {provider_type} for embedding")
         if api_key is None:
             logger.error("API key not provided for cloud model")
@@ -341,7 +319,6 @@
         cloud_model = CloudEmbedding(
             api_key=api_key,
             provider=provider_type,
-            model=model_name,
             api_url=api_url,
             api_version=api_version,
         )
diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py
index ca452640071..f10855f103f 100644
--- a/backend/shared_configs/configs.py
+++ b/backend/shared_configs/configs.py
@@ -63,8 +63,15 @@ DEV_LOGGING_ENABLED = os.environ.get("DEV_LOGGING_ENABLED", "").lower() == "true"
 # notset, debug, info, notice, warning, error, or critical
 LOG_LEVEL = os.environ.get("LOG_LEVEL", "notice")
 
+# Timeout for API-based embedding models
+# NOTE: does not apply for Google VertexAI, since the python client doesn't
+# allow us to specify a custom timeout
+API_BASED_EMBEDDING_TIMEOUT = int(os.environ.get("API_BASED_EMBEDDING_TIMEOUT", "600"))
+
 # Only used for OpenAI
-OPENAI_EMBEDDING_TIMEOUT = int(os.environ.get("OPENAI_EMBEDDING_TIMEOUT", "600"))
+OPENAI_EMBEDDING_TIMEOUT = int(
+    os.environ.get("OPENAI_EMBEDDING_TIMEOUT", API_BASED_EMBEDDING_TIMEOUT)
+)
 
 # Whether or not to strictly enforce token limit for chunking.
 STRICT_CHUNK_TOKEN_LIMIT = (
diff --git a/backend/tests/daily/embedding/test_embeddings.py b/backend/tests/daily/embedding/test_embeddings.py
index b736f374741..10a1dd850f6 100644
--- a/backend/tests/daily/embedding/test_embeddings.py
+++ b/backend/tests/daily/embedding/test_embeddings.py
@@ -61,6 +61,26 @@ def test_cohere_embedding(cohere_embedding_model: EmbeddingModel) -> None:
     _run_embeddings(TOO_LONG_SAMPLE, cohere_embedding_model, 384)
 
 
+@pytest.fixture
+def litellm_embedding_model() -> EmbeddingModel:
+    return EmbeddingModel(
+        server_host="localhost",
+        server_port=9000,
+        model_name="text-embedding-3-small",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+        api_key=os.getenv("LITE_LLM_API_KEY"),
+        provider_type=EmbeddingProvider.LITELLM,
+        api_url=os.getenv("LITE_LLM_API_URL"),
+    )
+
+
+def test_litellm_embedding(litellm_embedding_model: EmbeddingModel) -> None:
+    _run_embeddings(VALID_SAMPLE, litellm_embedding_model, 1536)
+    _run_embeddings(TOO_LONG_SAMPLE, litellm_embedding_model, 1536)
+
+
 @pytest.fixture
 def local_nomic_embedding_model() -> EmbeddingModel:
     return EmbeddingModel(
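Note (reviewer sketch, not part of the diff): with `_initialize_client` removed, `CloudEmbedding` now only stores credentials and config, and each `_embed_*` helper constructs its client per call. A rough usage sketch of the new LITELLM branch follows; the URL, key, and model are placeholders, and the `embed()` keyword names are inferred from the hunks above rather than confirmed against the full signature.

```python
from model_server.constants import EmbeddingProvider
from model_server.encoders import CloudEmbedding
from shared_configs.enums import EmbedTextType

# Construction is cheap now: no network client exists until embed() runs.
embedder = CloudEmbedding(
    api_key="sk-placeholder",  # placeholder; optional for unauthenticated proxies
    provider=EmbeddingProvider.LITELLM,
    api_url="http://localhost:4000/embeddings",  # placeholder proxy endpoint
)

vectors = embedder.embed(
    texts=["some passage to embed"],
    text_type=EmbedTextType.PASSAGE,
    model_name="text-embedding-3-small",  # required by the LiteLLM branch
)
```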
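The folded-in `_embed_litellm_proxy` keeps the same wire contract as the deleted module-level `embed_with_litellm_proxy`: a POST to an OpenAI-compatible embeddings endpoint. A minimal sketch of that contract, with a placeholder endpoint and model:

```python
import httpx

response = httpx.post(
    "http://localhost:4000/embeddings",  # placeholder LiteLLM proxy URL
    json={"model": "text-embedding-3-small", "input": ["hello", "world"]},
    headers={"Authorization": "Bearer sk-placeholder"},  # omitted when no api_key is set
    timeout=600,  # the new method passes API_BASED_EMBEDDING_TIMEOUT here
)
response.raise_for_status()
# Response mirrors OpenAI's schema: {"data": [{"embedding": [...]}, ...]}
vectors = [item["embedding"] for item in response.json()["data"]]
```

The behavioral change versus the removed helper is the explicit timeout: the old call relied on httpx's much shorter default.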
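Also worth noting: `OPENAI_EMBEDDING_TIMEOUT` still works but now falls back to `API_BASED_EMBEDDING_TIMEOUT` instead of a hard-coded 600. A tiny self-contained check mirroring the two lines in `shared_configs/configs.py`:

```python
import os

# Simulate: shared timeout overridden, OpenAI-specific one unset.
os.environ["API_BASED_EMBEDDING_TIMEOUT"] = "120"
os.environ.pop("OPENAI_EMBEDDING_TIMEOUT", None)

api_based = int(os.environ.get("API_BASED_EMBEDDING_TIMEOUT", "600"))
openai_timeout = int(os.environ.get("OPENAI_EMBEDDING_TIMEOUT", api_based))

# The OpenAI timeout inherits the shared value unless explicitly overridden.
assert (api_based, openai_timeout) == (120, 120)
```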