mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-20 01:53:14 +02:00
descrease model server logspam (#4166)
This commit is contained in:
@ -62,6 +62,60 @@ _OPENAI_MAX_INPUT_LEN = 2048
|
|||||||
# Cohere allows up to 96 embeddings in a single embedding calling
|
# Cohere allows up to 96 embeddings in a single embedding calling
|
||||||
_COHERE_MAX_INPUT_LEN = 96
|
_COHERE_MAX_INPUT_LEN = 96
|
||||||
|
|
||||||
|
# Authentication error string constants
|
||||||
|
_AUTH_ERROR_401 = "401"
|
||||||
|
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
|
||||||
|
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
|
||||||
|
_AUTH_ERROR_PERMISSION = "permission"
|
||||||
|
|
||||||
|
|
||||||
|
def is_authentication_error(error: Exception) -> bool:
|
||||||
|
"""Check if an exception is related to authentication issues.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: The exception to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the error appears to be authentication-related
|
||||||
|
"""
|
||||||
|
error_str = str(error).lower()
|
||||||
|
return (
|
||||||
|
_AUTH_ERROR_401 in error_str
|
||||||
|
or _AUTH_ERROR_UNAUTHORIZED in error_str
|
||||||
|
or _AUTH_ERROR_INVALID_API_KEY in error_str
|
||||||
|
or _AUTH_ERROR_PERMISSION in error_str
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def format_embedding_error(
|
||||||
|
error: Exception,
|
||||||
|
service_name: str,
|
||||||
|
model: str | None,
|
||||||
|
provider: EmbeddingProvider,
|
||||||
|
status_code: int | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format a standardized error string for embedding errors.
|
||||||
|
"""
|
||||||
|
detail = f"Status {status_code}" if status_code else f"{type(error)}"
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
|
||||||
|
f"Model: {model} "
|
||||||
|
f"Provider: {provider} "
|
||||||
|
f"Exception: {error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Custom exception for authentication errors
|
||||||
|
class AuthenticationError(Exception):
|
||||||
|
"""Raised when authentication fails with a provider."""
|
||||||
|
|
||||||
|
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
|
||||||
|
self.provider = provider
|
||||||
|
self.message = message
|
||||||
|
super().__init__(f"{provider} authentication failed: {message}")
|
||||||
|
|
||||||
|
|
||||||
class CloudEmbedding:
|
class CloudEmbedding:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -92,7 +146,7 @@ class CloudEmbedding:
|
|||||||
)
|
)
|
||||||
|
|
||||||
final_embeddings: list[Embedding] = []
|
final_embeddings: list[Embedding] = []
|
||||||
try:
|
|
||||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||||
response = await client.embeddings.create(
|
response = await client.embeddings.create(
|
||||||
input=text_batch,
|
input=text_batch,
|
||||||
@ -103,20 +157,6 @@ class CloudEmbedding:
|
|||||||
[embedding.embedding for embedding in response.data]
|
[embedding.embedding for embedding in response.data]
|
||||||
)
|
)
|
||||||
return final_embeddings
|
return final_embeddings
|
||||||
except Exception as e:
|
|
||||||
error_string = (
|
|
||||||
f"Exception embedding text with OpenAI - {type(e)}: "
|
|
||||||
f"Model: {model} "
|
|
||||||
f"Provider: {self.provider} "
|
|
||||||
f"Exception: {e}"
|
|
||||||
)
|
|
||||||
logger.error(error_string)
|
|
||||||
|
|
||||||
# only log text when it's not an authentication error.
|
|
||||||
if not isinstance(e, openai.AuthenticationError):
|
|
||||||
logger.debug(f"Exception texts: {texts}")
|
|
||||||
|
|
||||||
raise RuntimeError(error_string)
|
|
||||||
|
|
||||||
async def _embed_cohere(
|
async def _embed_cohere(
|
||||||
self, texts: list[str], model: str | None, embedding_type: str
|
self, texts: list[str], model: str | None, embedding_type: str
|
||||||
@ -155,7 +195,6 @@ class CloudEmbedding:
|
|||||||
input_type=embedding_type,
|
input_type=embedding_type,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.embeddings
|
return response.embeddings
|
||||||
|
|
||||||
async def _embed_azure(
|
async def _embed_azure(
|
||||||
@ -239,6 +278,7 @@ class CloudEmbedding:
|
|||||||
deployment_name: str | None = None,
|
deployment_name: str | None = None,
|
||||||
reduced_dimension: int | None = None,
|
reduced_dimension: int | None = None,
|
||||||
) -> list[Embedding]:
|
) -> list[Embedding]:
|
||||||
|
try:
|
||||||
if self.provider == EmbeddingProvider.OPENAI:
|
if self.provider == EmbeddingProvider.OPENAI:
|
||||||
return await self._embed_openai(texts, model_name, reduced_dimension)
|
return await self._embed_openai(texts, model_name, reduced_dimension)
|
||||||
elif self.provider == EmbeddingProvider.AZURE:
|
elif self.provider == EmbeddingProvider.AZURE:
|
||||||
@ -255,6 +295,34 @@ class CloudEmbedding:
|
|||||||
return await self._embed_vertex(texts, model_name, embedding_type)
|
return await self._embed_vertex(texts, model_name, embedding_type)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||||
|
except openai.AuthenticationError:
|
||||||
|
raise AuthenticationError(provider="OpenAI")
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if e.response.status_code == 401:
|
||||||
|
raise AuthenticationError(provider=str(self.provider))
|
||||||
|
|
||||||
|
error_string = format_embedding_error(
|
||||||
|
e,
|
||||||
|
str(self.provider),
|
||||||
|
model_name or deployment_name,
|
||||||
|
self.provider,
|
||||||
|
status_code=e.response.status_code,
|
||||||
|
)
|
||||||
|
logger.error(error_string)
|
||||||
|
logger.debug(f"Exception texts: {texts}")
|
||||||
|
|
||||||
|
raise RuntimeError(error_string)
|
||||||
|
except Exception as e:
|
||||||
|
if is_authentication_error(e):
|
||||||
|
raise AuthenticationError(provider=str(self.provider))
|
||||||
|
|
||||||
|
error_string = format_embedding_error(
|
||||||
|
e, str(self.provider), model_name or deployment_name, self.provider
|
||||||
|
)
|
||||||
|
logger.error(error_string)
|
||||||
|
logger.debug(f"Exception texts: {texts}")
|
||||||
|
|
||||||
|
raise RuntimeError(error_string)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(
|
def create(
|
||||||
@ -569,6 +637,13 @@ async def process_embed_request(
|
|||||||
gpu_type=gpu_type,
|
gpu_type=gpu_type,
|
||||||
)
|
)
|
||||||
return EmbedResponse(embeddings=embeddings)
|
return EmbedResponse(embeddings=embeddings)
|
||||||
|
except AuthenticationError as e:
|
||||||
|
# Handle authentication errors consistently
|
||||||
|
logger.error(f"Authentication error: {e.provider}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail=f"Authentication failed: {e.message}",
|
||||||
|
)
|
||||||
except RateLimitError as e:
|
except RateLimitError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=429,
|
status_code=429,
|
||||||
|
Reference in New Issue
Block a user