Refactor + add global timeout env variable (#2844)

* Refactor + add global timeout env variable

* remove model

* mypy

* Remove unused
Chris Weaver 2024-10-18 11:25:27 -07:00 committed by GitHub
parent 5b78299880
commit 36134021c5
3 changed files with 90 additions and 86 deletions
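At a high level, the refactor deletes the shared `_initialize_client` dispatcher and has each `_embed_*` method build its provider client at call time, so the new global timeout can be threaded through every API-based provider (Vertex AI excepted, per the config note below). A condensed, self-contained sketch of that shape (stub client, illustrative only; the real provider calls appear in the hunks below):

API_BASED_EMBEDDING_TIMEOUT = 600  # mirrors the new shared_configs default


class _StubProviderClient:
    """Stand-in for a provider SDK client, purely for illustration."""

    def __init__(self, api_key: str, timeout: int) -> None:
        self.api_key = api_key
        self.timeout = timeout

    def embed(self, texts: list[str]) -> list[list[float]]:
        return [[0.0] for _ in texts]


class CloudEmbeddingSketch:
    """The refactored class keeps raw credentials, not a pre-built client."""

    def __init__(self, api_key: str, api_url: str | None = None) -> None:
        self.api_key = api_key
        self.api_url = api_url

    def _embed_stub(self, texts: list[str]) -> list[list[float]]:
        # Client construction moves into the call path, so one shared
        # timeout constant reaches every provider uniformly.
        client = _StubProviderClient(self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT)
        return client.embed(texts)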

View File

@@ -1,5 +1,5 @@
 import json
-from typing import Any
+from typing import cast
 from typing import Optional
 import httpx
@@ -25,6 +25,7 @@ from model_server.constants import DEFAULT_VOYAGE_MODEL
 from model_server.constants import EmbeddingModelTextType
 from model_server.constants import EmbeddingProvider
 from model_server.utils import simple_log_function_time
+from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
 from shared_configs.enums import EmbedTextType
@@ -54,32 +55,6 @@ _OPENAI_MAX_INPUT_LEN = 2048
 _COHERE_MAX_INPUT_LEN = 96
 
 
-def _initialize_client(
-    api_key: str,
-    provider: EmbeddingProvider,
-    model: str | None = None,
-    api_url: str | None = None,
-    api_version: str | None = None,
-) -> Any:
-    if provider == EmbeddingProvider.OPENAI:
-        return openai.OpenAI(api_key=api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
-    elif provider == EmbeddingProvider.COHERE:
-        return CohereClient(api_key=api_key)
-    elif provider == EmbeddingProvider.VOYAGE:
-        return voyageai.Client(api_key=api_key)
-    elif provider == EmbeddingProvider.GOOGLE:
-        credentials = service_account.Credentials.from_service_account_info(
-            json.loads(api_key)
-        )
-        project_id = json.loads(api_key)["project_id"]
-        vertexai.init(project=project_id, credentials=credentials)
-        return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
-    elif provider == EmbeddingProvider.AZURE:
-        return {"api_key": api_key, "api_url": api_url, "api_version": api_version}
-    else:
-        raise ValueError(f"Unsupported provider: {provider}")
-
-
 class CloudEmbedding:
     def __init__(
         self,
@@ -87,25 +62,22 @@ class CloudEmbedding:
         provider: EmbeddingProvider,
         api_url: str | None = None,
         api_version: str | None = None,
-        # Only for Google as is needed on client setup
-        model: str | None = None,
     ) -> None:
         self.provider = provider
-        self.client = _initialize_client(
-            api_key, self.provider, model, api_url, api_version
-        )
+        self.api_key = api_key
+        self.api_url = api_url
+        self.api_version = api_version
 
     def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
         if not model:
             model = DEFAULT_OPENAI_MODEL
 
         # OpenAI does not seem to provide truncation option, however
         # the context lengths used by Danswer currently are smaller than the max token length
         # for OpenAI embeddings so it's not a big deal
+        client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
+
         final_embeddings: list[Embedding] = []
         try:
             for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
-                response = self.client.embeddings.create(input=text_batch, model=model)
+                response = client.embeddings.create(input=text_batch, model=model)
                 final_embeddings.extend(
                     [embedding.embedding for embedding in response.data]
                 )
@@ -126,17 +98,19 @@
         if not model:
             model = DEFAULT_COHERE_MODEL
 
+        client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT)
+
         final_embeddings: list[Embedding] = []
         for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
             # Does not use the same tokenizer as the Danswer API server but it's approximately the same
             # empirically it's only off by a very few tokens so it's not a big deal
-            response = self.client.embed(
+            response = client.embed(
                 texts=text_batch,
                 model=model,
                 input_type=embedding_type,
                 truncate="END",
             )
-            final_embeddings.extend(response.embeddings)
+            final_embeddings.extend(cast(list[Embedding], response.embeddings))
 
         return final_embeddings
 
     def _embed_voyage(
@@ -145,13 +119,15 @@
         if not model:
             model = DEFAULT_VOYAGE_MODEL
 
         # Similar to Cohere, the API server will do approximate size chunking
         # it's acceptable to miss by a few tokens
-        response = self.client.embed(
+        client = voyageai.Client(
+            api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
+        )
+
+        response = client.embed(
             texts,
             model=model,
             input_type=embedding_type,
-            truncation=True,  # Also this is default
+            truncation=True,
         )
         return response.embeddings
@@ -159,9 +135,10 @@
         response = embedding(
             model=model,
             input=texts,
-            api_key=self.client["api_key"],
-            api_base=self.client["api_url"],
-            api_version=self.client["api_version"],
+            timeout=API_BASED_EMBEDDING_TIMEOUT,
+            api_key=self.api_key,
+            api_base=self.api_url,
+            api_version=self.api_version,
         )
         embeddings = [embedding["embedding"] for embedding in response.data]
@@ -173,7 +150,14 @@
         if not model:
             model = DEFAULT_VERTEX_MODEL
 
-        embeddings = self.client.get_embeddings(
+        credentials = service_account.Credentials.from_service_account_info(
+            json.loads(self.api_key)
+        )
+        project_id = json.loads(self.api_key)["project_id"]
+        vertexai.init(project=project_id, credentials=credentials)
+        client = TextEmbeddingModel.from_pretrained(model)
+
+        embeddings = client.get_embeddings(
             [
                 TextEmbeddingInput(
                     text,
@@ -185,6 +169,33 @@
             )
         return [embedding.values for embedding in embeddings]
 
+    def _embed_litellm_proxy(
+        self, texts: list[str], model_name: str | None
+    ) -> list[Embedding]:
+        if not model_name:
+            raise ValueError("Model name is required for LiteLLM proxy embedding.")
+
+        if not self.api_url:
+            raise ValueError("API URL is required for LiteLLM proxy embedding.")
+
+        headers = (
+            {} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
+        )
+
+        with httpx.Client() as client:
+            response = client.post(
+                self.api_url,
+                json={
+                    "model": model_name,
+                    "input": texts,
+                },
+                headers=headers,
+                timeout=API_BASED_EMBEDDING_TIMEOUT,
+            )
+            response.raise_for_status()
+            result = response.json()
+            return [embedding["embedding"] for embedding in result["data"]]
+
     @retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
     def embed(
         self,
@@ -199,6 +210,9 @@
             return self._embed_openai(texts, model_name)
         elif self.provider == EmbeddingProvider.AZURE:
             return self._embed_azure(texts, f"azure/{deployment_name}")
+        elif self.provider == EmbeddingProvider.LITELLM:
+            return self._embed_litellm_proxy(texts, model_name)
+
         embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
         if self.provider == EmbeddingProvider.COHERE:
             return self._embed_cohere(texts, model_name, embedding_type)
@@ -218,12 +232,11 @@
     def create(
         api_key: str,
         provider: EmbeddingProvider,
-        model: str | None = None,
         api_url: str | None = None,
         api_version: str | None = None,
     ) -> "CloudEmbedding":
         logger.debug(f"Creating Embedding instance for provider: {provider}")
-        return CloudEmbedding(api_key, provider, model, api_url, api_version)
+        return CloudEmbedding(api_key, provider, api_url, api_version)
 
 
 def get_embedding_model(
@@ -266,25 +279,6 @@ def get_local_reranking_model(
     return _RERANK_MODEL
 
 
-def embed_with_litellm_proxy(
-    texts: list[str], api_url: str, model_name: str, api_key: str | None
-) -> list[Embedding]:
-    headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
-
-    with httpx.Client() as client:
-        response = client.post(
-            api_url,
-            json={
-                "model": model_name,
-                "input": texts,
-            },
-            headers=headers,
-        )
-        response.raise_for_status()
-        result = response.json()
-        return [embedding["embedding"] for embedding in result["data"]]
-
-
 @simple_log_function_time()
 def embed_text(
     texts: list[str],
@@ -309,23 +303,7 @@
         logger.error("No texts provided for embedding")
         raise ValueError("No texts provided for embedding.")
 
-    if provider_type == EmbeddingProvider.LITELLM:
-        logger.debug(f"Using LiteLLM proxy for embedding with URL: {api_url}")
-        if not api_url:
-            logger.error("API URL not provided for LiteLLM proxy")
-            raise ValueError("API URL is required for LiteLLM proxy embedding.")
-
-        try:
-            return embed_with_litellm_proxy(
-                texts=texts,
-                api_url=api_url,
-                model_name=model_name or "",
-                api_key=api_key,
-            )
-        except Exception as e:
-            logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}")
-            raise
-
-    elif provider_type is not None:
+    if provider_type is not None:
         logger.debug(f"Using cloud provider {provider_type} for embedding")
         if api_key is None:
             logger.error("API key not provided for cloud model")
@@ -341,7 +319,6 @@
         cloud_model = CloudEmbedding(
             api_key=api_key,
             provider=provider_type,
-            model=model_name,
            api_url=api_url,
            api_version=api_version,
        )
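For reference, the `_embed_litellm_proxy` method added above assumes the proxy returns an OpenAI-compatible JSON body. A minimal sketch of the shape it parses (payload values made up for illustration):

# Hypothetical OpenAI-compatible response for two input texts; the new
# method extracts one embedding per element of "data".
result = {
    "data": [
        {"embedding": [0.12, -0.03, 0.98]},
        {"embedding": [0.44, 0.10, -0.25]},
    ]
}

embeddings = [item["embedding"] for item in result["data"]]
assert len(embeddings) == 2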

View File

@@ -63,8 +63,15 @@ DEV_LOGGING_ENABLED = os.environ.get("DEV_LOGGING_ENABLED", "").lower() == "true"
 # notset, debug, info, notice, warning, error, or critical
 LOG_LEVEL = os.environ.get("LOG_LEVEL", "notice")
 
+# Timeout for API-based embedding models
+# NOTE: does not apply for Google VertexAI, since the python client doesn't
+# allow us to specify a custom timeout
+API_BASED_EMBEDDING_TIMEOUT = int(os.environ.get("API_BASED_EMBEDDING_TIMEOUT", "600"))
+
 # Only used for OpenAI
-OPENAI_EMBEDDING_TIMEOUT = int(os.environ.get("OPENAI_EMBEDDING_TIMEOUT", "600"))
+OPENAI_EMBEDDING_TIMEOUT = int(
+    os.environ.get("OPENAI_EMBEDDING_TIMEOUT", API_BASED_EMBEDDING_TIMEOUT)
+)
 
 # Whether or not to strictly enforce token limit for chunking.
 STRICT_CHUNK_TOKEN_LIMIT = (
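A quick check of the fallback semantics above: OPENAI_EMBEDDING_TIMEOUT now inherits the global value unless it is set explicitly (a sketch; run in a fresh process):

import os

# Assume only the global knob is set in the environment.
os.environ["API_BASED_EMBEDDING_TIMEOUT"] = "120"
os.environ.pop("OPENAI_EMBEDDING_TIMEOUT", None)

API_BASED_EMBEDDING_TIMEOUT = int(os.environ.get("API_BASED_EMBEDDING_TIMEOUT", "600"))
OPENAI_EMBEDDING_TIMEOUT = int(
    os.environ.get("OPENAI_EMBEDDING_TIMEOUT", API_BASED_EMBEDDING_TIMEOUT)
)

assert (API_BASED_EMBEDDING_TIMEOUT, OPENAI_EMBEDDING_TIMEOUT) == (120, 120)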

View File

@@ -61,6 +61,26 @@ def test_cohere_embedding(cohere_embedding_model: EmbeddingModel) -> None:
     _run_embeddings(TOO_LONG_SAMPLE, cohere_embedding_model, 384)
 
 
+@pytest.fixture
+def litellm_embedding_model() -> EmbeddingModel:
+    return EmbeddingModel(
+        server_host="localhost",
+        server_port=9000,
+        model_name="text-embedding-3-small",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+        api_key=os.getenv("LITE_LLM_API_KEY"),
+        provider_type=EmbeddingProvider.LITELLM,
+        api_url=os.getenv("LITE_LLM_API_URL"),
+    )
+
+
+def test_litellm_embedding(litellm_embedding_model: EmbeddingModel) -> None:
+    _run_embeddings(VALID_SAMPLE, litellm_embedding_model, 1536)
+    _run_embeddings(TOO_LONG_SAMPLE, litellm_embedding_model, 1536)
+
+
 @pytest.fixture
 def local_nomic_embedding_model() -> EmbeddingModel:
     return EmbeddingModel(
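The new litellm test only passes when the fixture can reach a running LiteLLM proxy; the two environment variables it reads would be set to something like the following before invoking pytest (both values are placeholders, and the exact URL depends on how the proxy is deployed):

import os

# Placeholder values for illustration only; point these at a real proxy.
os.environ["LITE_LLM_API_KEY"] = "sk-placeholder"
os.environ["LITE_LLM_API_URL"] = "http://localhost:4000/embeddings"  # hypothetical path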