Refactor + add global timeout env variable (#2844)

* Refactor + add global timeout env variable
* remove model
* mypy
* Remove unused

parent 5b78299880
commit 36134021c5
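
This commit drops the shared _initialize_client helper, has each _embed_* method build its provider client per call with an explicit timeout, folds the LiteLLM proxy path into CloudEmbedding, and introduces a global API_BASED_EMBEDDING_TIMEOUT environment variable. A minimal sketch of the resulting timeout resolution, mirroring the shared_configs hunk further down (values in seconds):

import os

# Global timeout for all API-based embedding providers; defaults to 600s.
API_BASED_EMBEDDING_TIMEOUT = int(os.environ.get("API_BASED_EMBEDDING_TIMEOUT", "600"))

# OpenAI keeps its own knob but now falls back to the global value.
OPENAI_EMBEDDING_TIMEOUT = int(
    os.environ.get("OPENAI_EMBEDDING_TIMEOUT", API_BASED_EMBEDDING_TIMEOUT)
)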
@@ -1,5 +1,5 @@
 import json
-from typing import Any
+from typing import cast
 from typing import Optional
 
 import httpx
@@ -25,6 +25,7 @@ from model_server.constants import DEFAULT_VOYAGE_MODEL
 from model_server.constants import EmbeddingModelTextType
 from model_server.constants import EmbeddingProvider
 from model_server.utils import simple_log_function_time
+from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
 from shared_configs.enums import EmbedTextType
@@ -54,32 +55,6 @@ _OPENAI_MAX_INPUT_LEN = 2048
 _COHERE_MAX_INPUT_LEN = 96
 
 
-def _initialize_client(
-    api_key: str,
-    provider: EmbeddingProvider,
-    model: str | None = None,
-    api_url: str | None = None,
-    api_version: str | None = None,
-) -> Any:
-    if provider == EmbeddingProvider.OPENAI:
-        return openai.OpenAI(api_key=api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
-    elif provider == EmbeddingProvider.COHERE:
-        return CohereClient(api_key=api_key)
-    elif provider == EmbeddingProvider.VOYAGE:
-        return voyageai.Client(api_key=api_key)
-    elif provider == EmbeddingProvider.GOOGLE:
-        credentials = service_account.Credentials.from_service_account_info(
-            json.loads(api_key)
-        )
-        project_id = json.loads(api_key)["project_id"]
-        vertexai.init(project=project_id, credentials=credentials)
-        return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
-    elif provider == EmbeddingProvider.AZURE:
-        return {"api_key": api_key, "api_url": api_url, "api_version": api_version}
-    else:
-        raise ValueError(f"Unsupported provider: {provider}")
-
-
 class CloudEmbedding:
     def __init__(
         self,
@@ -87,25 +62,22 @@ class CloudEmbedding:
         provider: EmbeddingProvider,
         api_url: str | None = None,
         api_version: str | None = None,
-        # Only for Google as is needed on client setup
-        model: str | None = None,
     ) -> None:
         self.provider = provider
-        self.client = _initialize_client(
-            api_key, self.provider, model, api_url, api_version
-        )
+        self.api_key = api_key
+        self.api_url = api_url
+        self.api_version = api_version
 
     def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
         if not model:
             model = DEFAULT_OPENAI_MODEL
 
-        # OpenAI does not seem to provide truncation option, however
-        # the context lengths used by Danswer currently are smaller than the max token length
-        # for OpenAI embeddings so it's not a big deal
+        client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
+
         final_embeddings: list[Embedding] = []
         try:
             for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
-                response = self.client.embeddings.create(input=text_batch, model=model)
+                response = client.embeddings.create(input=text_batch, model=model)
                 final_embeddings.extend(
                     [embedding.embedding for embedding in response.data]
                 )
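
The pattern above repeats for every provider below: instead of caching a client built by _initialize_client in __init__, each _embed_* method constructs its client at call time, which is where the new timeouts are threaded in. A hedged usage sketch of the slimmed-down constructor (the model argument is gone; the key, URL, and version are simply stored):

# illustrative only; the key is a placeholder
embedder = CloudEmbedding(
    api_key="sk-...",
    provider=EmbeddingProvider.OPENAI,
)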
@@ -126,17 +98,19 @@
         if not model:
             model = DEFAULT_COHERE_MODEL
 
+        client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT)
+
         final_embeddings: list[Embedding] = []
         for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
             # Does not use the same tokenizer as the Danswer API server but it's approximately the same
             # empirically it's only off by a very few tokens so it's not a big deal
-            response = self.client.embed(
+            response = client.embed(
                 texts=text_batch,
                 model=model,
                 input_type=embedding_type,
                 truncate="END",
             )
-            final_embeddings.extend(response.embeddings)
+            final_embeddings.extend(cast(list[Embedding], response.embeddings))
         return final_embeddings
 
     def _embed_voyage(
@@ -145,13 +119,15 @@
         if not model:
             model = DEFAULT_VOYAGE_MODEL
 
-        # Similar to Cohere, the API server will do approximate size chunking
-        # it's acceptable to miss by a few tokens
-        response = self.client.embed(
+        client = voyageai.Client(
+            api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
+        )
+
+        response = client.embed(
             texts,
             model=model,
             input_type=embedding_type,
-            truncation=True,  # Also this is default
+            truncation=True,
         )
         return response.embeddings
 
@@ -159,9 +135,10 @@
         response = embedding(
             model=model,
             input=texts,
-            api_key=self.client["api_key"],
-            api_base=self.client["api_url"],
-            api_version=self.client["api_version"],
+            timeout=API_BASED_EMBEDDING_TIMEOUT,
+            api_key=self.api_key,
+            api_base=self.api_url,
+            api_version=self.api_version,
         )
         embeddings = [embedding["embedding"] for embedding in response.data]
 
@@ -173,7 +150,14 @@
         if not model:
             model = DEFAULT_VERTEX_MODEL
 
-        embeddings = self.client.get_embeddings(
+        credentials = service_account.Credentials.from_service_account_info(
+            json.loads(self.api_key)
+        )
+        project_id = json.loads(self.api_key)["project_id"]
+        vertexai.init(project=project_id, credentials=credentials)
+        client = TextEmbeddingModel.from_pretrained(model)
+
+        embeddings = client.get_embeddings(
             [
                 TextEmbeddingInput(
                     text,
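
Note the asymmetry in the Vertex path above: it gets the same per-call client construction as the other providers but no timeout argument, since (as the shared_configs hunk later in this commit notes) the Vertex python client does not expose one.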
@@ -185,6 +169,33 @@
         )
         return [embedding.values for embedding in embeddings]
 
+    def _embed_litellm_proxy(
+        self, texts: list[str], model_name: str | None
+    ) -> list[Embedding]:
+        if not model_name:
+            raise ValueError("Model name is required for LiteLLM proxy embedding.")
+
+        if not self.api_url:
+            raise ValueError("API URL is required for LiteLLM proxy embedding.")
+
+        headers = (
+            {} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
+        )
+
+        with httpx.Client() as client:
+            response = client.post(
+                self.api_url,
+                json={
+                    "model": model_name,
+                    "input": texts,
+                },
+                headers=headers,
+                timeout=API_BASED_EMBEDDING_TIMEOUT,
+            )
+            response.raise_for_status()
+            result = response.json()
+            return [embedding["embedding"] for embedding in result["data"]]
+
     @retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
     def embed(
         self,
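
The new _embed_litellm_proxy method speaks the OpenAI-compatible embeddings protocol to the proxy. A hedged sketch of the wire format the code above assumes (model name and values are illustrative):

request_json = {"model": "text-embedding-3-small", "input": ["first chunk", "second chunk"]}
# expected response shape, per the parsing above:
# {"data": [{"embedding": [0.01, ...]}, {"embedding": [0.02, ...]}]}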
@@ -199,6 +210,9 @@
             return self._embed_openai(texts, model_name)
         elif self.provider == EmbeddingProvider.AZURE:
             return self._embed_azure(texts, f"azure/{deployment_name}")
+        elif self.provider == EmbeddingProvider.LITELLM:
+            return self._embed_litellm_proxy(texts, model_name)
+
         embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
         if self.provider == EmbeddingProvider.COHERE:
             return self._embed_cohere(texts, model_name, embedding_type)
@@ -218,12 +232,11 @@
     def create(
         api_key: str,
         provider: EmbeddingProvider,
-        model: str | None = None,
         api_url: str | None = None,
         api_version: str | None = None,
     ) -> "CloudEmbedding":
         logger.debug(f"Creating Embedding instance for provider: {provider}")
-        return CloudEmbedding(api_key, provider, model, api_url, api_version)
+        return CloudEmbedding(api_key, provider, api_url, api_version)
 
 
 def get_embedding_model(
@@ -266,25 +279,6 @@ def get_local_reranking_model(
     return _RERANK_MODEL
 
 
-def embed_with_litellm_proxy(
-    texts: list[str], api_url: str, model_name: str, api_key: str | None
-) -> list[Embedding]:
-    headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
-
-    with httpx.Client() as client:
-        response = client.post(
-            api_url,
-            json={
-                "model": model_name,
-                "input": texts,
-            },
-            headers=headers,
-        )
-        response.raise_for_status()
-        result = response.json()
-        return [embedding["embedding"] for embedding in result["data"]]
-
-
 @simple_log_function_time()
 def embed_text(
     texts: list[str],
@@ -309,23 +303,7 @@ def embed_text(
         logger.error("No texts provided for embedding")
         raise ValueError("No texts provided for embedding.")
 
-    if provider_type == EmbeddingProvider.LITELLM:
-        logger.debug(f"Using LiteLLM proxy for embedding with URL: {api_url}")
-        if not api_url:
-            logger.error("API URL not provided for LiteLLM proxy")
-            raise ValueError("API URL is required for LiteLLM proxy embedding.")
-        try:
-            return embed_with_litellm_proxy(
-                texts=texts,
-                api_url=api_url,
-                model_name=model_name or "",
-                api_key=api_key,
-            )
-        except Exception as e:
-            logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}")
-            raise
-
-    elif provider_type is not None:
+    if provider_type is not None:
         logger.debug(f"Using cloud provider {provider_type} for embedding")
         if api_key is None:
             logger.error("API key not provided for cloud model")
@@ -341,7 +319,6 @@ def embed_text(
         cloud_model = CloudEmbedding(
             api_key=api_key,
             provider=provider_type,
-            model=model_name,
             api_url=api_url,
             api_version=api_version,
         )
@@ -63,8 +63,15 @@ DEV_LOGGING_ENABLED = os.environ.get("DEV_LOGGING_ENABLED", "").lower() == "true"
 # notset, debug, info, notice, warning, error, or critical
 LOG_LEVEL = os.environ.get("LOG_LEVEL", "notice")
 
+# Timeout for API-based embedding models
+# NOTE: does not apply for Google VertexAI, since the python client doesn't
+# allow us to specify a custom timeout
+API_BASED_EMBEDDING_TIMEOUT = int(os.environ.get("API_BASED_EMBEDDING_TIMEOUT", "600"))
+
 # Only used for OpenAI
-OPENAI_EMBEDDING_TIMEOUT = int(os.environ.get("OPENAI_EMBEDDING_TIMEOUT", "600"))
+OPENAI_EMBEDDING_TIMEOUT = int(
+    os.environ.get("OPENAI_EMBEDDING_TIMEOUT", API_BASED_EMBEDDING_TIMEOUT)
+)
 
 # Whether or not to strictly enforce token limit for chunking.
 STRICT_CHUNK_TOKEN_LIMIT = (
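
Both values are read once, at import time of shared_configs.configs, so the variables must be present in the process environment before startup. A hedged example of tightening the global timeout (the value is illustrative):

import os

# must run before shared_configs.configs is first imported
os.environ["API_BASED_EMBEDDING_TIMEOUT"] = "120"  # seconds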
@@ -61,6 +61,26 @@ def test_cohere_embedding(cohere_embedding_model: EmbeddingModel) -> None:
     _run_embeddings(TOO_LONG_SAMPLE, cohere_embedding_model, 384)
 
 
+@pytest.fixture
+def litellm_embedding_model() -> EmbeddingModel:
+    return EmbeddingModel(
+        server_host="localhost",
+        server_port=9000,
+        model_name="text-embedding-3-small",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+        api_key=os.getenv("LITE_LLM_API_KEY"),
+        provider_type=EmbeddingProvider.LITELLM,
+        api_url=os.getenv("LITE_LLM_API_URL"),
+    )
+
+
+def test_litellm_embedding(litellm_embedding_model: EmbeddingModel) -> None:
+    _run_embeddings(VALID_SAMPLE, litellm_embedding_model, 1536)
+    _run_embeddings(TOO_LONG_SAMPLE, litellm_embedding_model, 1536)
+
+
 @pytest.fixture
 def local_nomic_embedding_model() -> EmbeddingModel:
     return EmbeddingModel(
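
The new test only passes against a model server on localhost:9000 with LiteLLM credentials in the environment; the expected dimension of 1536 matches OpenAI's text-embedding-3-small served through the proxy. A hedged sketch of the required environment (URL and key are placeholders, not values from the commit):

import os

os.environ["LITE_LLM_API_URL"] = "http://localhost:4000/embeddings"  # assumption: example proxy endpoint
os.environ["LITE_LLM_API_KEY"] = "sk-proxy-..."  # assumption: placeholder key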