Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-03-29 11:12:02 +01:00
Add handling for rate limiting (#3280)
commit ac448956e9
parent 634a0b9398
@@ -195,7 +195,7 @@ def connector_external_group_sync_generator_task(
     tenant_id: str | None,
 ) -> None:
     """
-    Permission sync task that handles document permission syncing for a given connector credential pair
+    Permission sync task that handles external group syncing for a given connector credential pair
     This task assumes that the task has already been properly fenced
     """
 
@@ -228,9 +228,13 @@ def connector_external_group_sync_generator_task(
 
         ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
         if ext_group_sync_func is None:
-            raise ValueError(f"No external group sync func found for {source_type}")
+            raise ValueError(
+                f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
+            )
 
-        logger.info(f"Syncing docs for {source_type}")
+        logger.info(
+            f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
+        )
 
         external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
 
@@ -0,0 +1,4 @@
+class ModelServerRateLimitError(Exception):
+    """
+    Exception raised for rate limiting errors from the model server.
+    """
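A dedicated exception type lets callers distinguish rate limiting from every other failure mode; the client-side hunk further down raises it when the model server answers with HTTP 429. A minimal sketch of that pattern, self-contained rather than taken verbatim from the diff:

import requests


# Stand-in mirroring the class added above, defined here so the sketch runs alone.
class ModelServerRateLimitError(Exception):
    """Exception raised for rate limiting errors from the model server."""


def check_for_rate_limit(response: requests.Response) -> None:
    # HTTP 429 "Too Many Requests" is the upstream rate-limit signal
    if response.status_code == 429:
        raise ModelServerRateLimitError(response.text)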
@@ -6,6 +6,9 @@ from typing import Any
 
 import requests
 from httpx import HTTPError
+from requests import JSONDecodeError
+from requests import RequestException
+from requests import Response
 from retry import retry
 
 from danswer.configs.app_configs import LARGE_CHUNK_RATIO
@@ -16,6 +19,9 @@ from danswer.configs.model_configs import (
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.models import SearchSettings
 from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
+from danswer.natural_language_processing.exceptions import (
+    ModelServerRateLimitError,
+)
 from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.natural_language_processing.utils import tokenizer_trim_content
 from danswer.utils.logger import setup_logger
@@ -99,28 +105,43 @@ class EmbeddingModel:
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
 
     def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
-        def _make_request() -> EmbedResponse:
+        def _make_request() -> Response:
             response = requests.post(
                 self.embed_server_endpoint, json=embed_request.model_dump()
             )
-            try:
-                response.raise_for_status()
-            except requests.HTTPError as e:
-                try:
-                    error_detail = response.json().get("detail", str(e))
-                except Exception:
-                    error_detail = response.text
-                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
-            except requests.RequestException as e:
-                raise HTTPError(f"Request failed: {str(e)}") from e
+            # signify that this is a rate limit error
+            if response.status_code == 429:
+                raise ModelServerRateLimitError(response.text)
 
-            return EmbedResponse(**response.json())
+            response.raise_for_status()
+            return response
 
-        # only perform retries for the non-realtime embedding of passages (e.g. for indexing)
+        final_make_request_func = _make_request
+
+        # if the text type is a passage, add some default
+        # retries + handling for rate limiting
         if embed_request.text_type == EmbedTextType.PASSAGE:
-            return retry(tries=3, delay=5)(_make_request)()
-        else:
-            return _make_request()
+            final_make_request_func = retry(
+                tries=3,
+                delay=5,
+                exceptions=(RequestException, ValueError, JSONDecodeError),
+            )(final_make_request_func)
+            # use 10 second delay as per Azure suggestion
+            final_make_request_func = retry(
+                tries=10, delay=10, exceptions=ModelServerRateLimitError
+            )(final_make_request_func)
+
+        try:
+            response = final_make_request_func()
+            return EmbedResponse(**response.json())
+        except requests.HTTPError as e:
+            try:
+                error_detail = response.json().get("detail", str(e))
+            except Exception:
+                error_detail = response.text
+            raise HTTPError(f"HTTP error occurred: {error_detail}") from e
+        except requests.RequestException as e:
+            raise HTTPError(f"Request failed: {str(e)}") from e
 
     def _batch_encode_texts(
         self,
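A subtlety in the hunk above: the two retry wrappers compose, and the rate-limit wrapper is applied second, so it sits outermost. Because ModelServerRateLimitError is not in the inner wrapper's exceptions tuple, a 429 skips the 3-try/5-second loop entirely and is retried only by the 10-try/10-second loop. A standalone sketch of the same composition; make_request here is a hypothetical stand-in:

from requests import JSONDecodeError, RequestException
from retry import retry


class ModelServerRateLimitError(Exception):
    """Stand-in for the exception added in this commit."""


def make_request() -> str:
    ...  # would POST to the model server and raise on failure


# inner wrapper: transient request/parse failures, 3 tries, 5 seconds apart
wrapped = retry(
    tries=3, delay=5, exceptions=(RequestException, ValueError, JSONDecodeError)
)(make_request)
# outer wrapper: rate limits only, 10 tries, 10 seconds apart (the Azure suggestion)
wrapped = retry(tries=10, delay=10, exceptions=ModelServerRateLimitError)(wrapped)

In the worst case a persistently rate-limited passage request waits roughly 90 seconds before the error propagates.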
@@ -11,6 +11,7 @@ from fastapi import APIRouter
 from fastapi import HTTPException
 from google.oauth2 import service_account  # type: ignore
 from litellm import embedding
+from litellm.exceptions import RateLimitError
 from retry import retry
 from sentence_transformers import CrossEncoder  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore
@@ -205,28 +206,22 @@ class CloudEmbedding:
         model_name: str | None = None,
         deployment_name: str | None = None,
     ) -> list[Embedding]:
-        try:
-            if self.provider == EmbeddingProvider.OPENAI:
-                return self._embed_openai(texts, model_name)
-            elif self.provider == EmbeddingProvider.AZURE:
-                return self._embed_azure(texts, f"azure/{deployment_name}")
-            elif self.provider == EmbeddingProvider.LITELLM:
-                return self._embed_litellm_proxy(texts, model_name)
+        if self.provider == EmbeddingProvider.OPENAI:
+            return self._embed_openai(texts, model_name)
+        elif self.provider == EmbeddingProvider.AZURE:
+            return self._embed_azure(texts, f"azure/{deployment_name}")
+        elif self.provider == EmbeddingProvider.LITELLM:
+            return self._embed_litellm_proxy(texts, model_name)
 
-            embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
-            if self.provider == EmbeddingProvider.COHERE:
-                return self._embed_cohere(texts, model_name, embedding_type)
-            elif self.provider == EmbeddingProvider.VOYAGE:
-                return self._embed_voyage(texts, model_name, embedding_type)
-            elif self.provider == EmbeddingProvider.GOOGLE:
-                return self._embed_vertex(texts, model_name, embedding_type)
-            else:
-                raise ValueError(f"Unsupported provider: {self.provider}")
-        except Exception as e:
-            raise HTTPException(
-                status_code=500,
-                detail=f"Error embedding text with {self.provider}: {str(e)}",
-            )
+        embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
+        if self.provider == EmbeddingProvider.COHERE:
+            return self._embed_cohere(texts, model_name, embedding_type)
+        elif self.provider == EmbeddingProvider.VOYAGE:
+            return self._embed_voyage(texts, model_name, embedding_type)
+        elif self.provider == EmbeddingProvider.GOOGLE:
+            return self._embed_vertex(texts, model_name, embedding_type)
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")
 
     @staticmethod
     def create(
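The deleted try/except is the point of this hunk: the old blanket handler converted every failure, including litellm's RateLimitError, into a generic 500 before anything upstream could react. With it gone, the error reaches process_embed_request (next hunk), which maps it to a proper 429. An illustrative snippet, not from the repo, of why a catch-all at this level shadows more specific handling above it:

class RateLimitError(Exception):
    """Stand-in for litellm.exceptions.RateLimitError."""


def embed() -> None:
    raise RateLimitError("slow down")


try:
    try:
        embed()
    except Exception as exc:  # the old catch-all: every error becomes a 500
        raise RuntimeError("HTTP 500") from exc
except RateLimitError:
    print("never reached: the specific handler is shadowed")
except RuntimeError as exc:
    print(exc)  # prints "HTTP 500"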
@@ -430,6 +425,11 @@ async def process_embed_request(
             prefix=prefix,
         )
         return EmbedResponse(embeddings=embeddings)
+    except RateLimitError as e:
+        raise HTTPException(
+            status_code=429,
+            detail=str(e),
+        )
     except Exception as e:
         exception_detail = f"Error during embedding process:\n{str(e)}"
         logger.exception(exception_detail)
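This closes the loop with the client-side change: the model server now surfaces provider rate limits as HTTP 429, which is exactly the status _make_request checks for before raising ModelServerRateLimitError. A hedged sketch of the round trip against a locally running model server; only the endpoint path and port come from this diff, the payload fields are guesses:

import requests

response = requests.post(
    "http://localhost:9000/encoder/bi-encoder-embed",
    json={"texts": ["hello"], "text_type": "passage"},  # hypothetical payload
)
if response.status_code == 429:
    print("rate limited by the upstream provider; back off and retry")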
@@ -7,6 +7,7 @@ from shared_configs.enums import EmbedTextType
 from shared_configs.model_server_models import EmbeddingProvider
 
 VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
+VALID_LONG_SAMPLE = ["hi " * 999]
 # openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
 # seem to be true
 TOO_LONG_SAMPLE = ["a"] * 2500
@@ -99,3 +100,42 @@ def local_nomic_embedding_model() -> EmbeddingModel:
 def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None:
     _run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768)
     _run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768)
+
+
+@pytest.fixture
+def azure_embedding_model() -> EmbeddingModel:
+    return EmbeddingModel(
+        server_host="localhost",
+        server_port=9000,
+        model_name="text-embedding-3-large",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+        api_key=os.getenv("AZURE_API_KEY"),
+        provider_type=EmbeddingProvider.AZURE,
+        api_url=os.getenv("AZURE_API_URL"),
+    )
+
+
+# NOTE (chris): this test doesn't work, and I do not know why
+# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel):
+#     """NOTE: this test relies on a very low rate limit for the Azure API +
+#     this test only being run once in a 1 minute window"""
+#     # VALID_LONG_SAMPLE is 999 tokens, so the second call should run into rate
+#     # limits assuming the limit is 1000 tokens per minute
+#     result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+#     assert len(result) == 1
+#     assert len(result[0]) == 1536
+
+#     # this should fail
+#     with pytest.raises(ModelServerRateLimitError):
+#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+
+#     # this should succeed, since passage requests retry up to 10 times
+#     start = time.time()
+#     result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.PASSAGE)
+#     assert len(result) == 1
+#     assert len(result[0]) == 1536
+#     assert time.time() - start > 30  # make sure we waited, even though we hit rate limits
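The disabled test depends on real Azure quota and wall-clock timing, which makes it inherently flaky, as its NOTE admits. A deterministic alternative (not in the commit) is to exercise the retry behavior directly by failing a stubbed call a fixed number of times:

from retry import retry


class ModelServerRateLimitError(Exception):
    """Stand-in for the exception added in this commit."""


def test_retry_recovers_from_rate_limit() -> None:
    attempts = {"count": 0}

    # fail twice with a rate-limit error, then succeed; delay=0 keeps the test fast
    @retry(tries=10, delay=0, exceptions=ModelServerRateLimitError)
    def call() -> str:
        attempts["count"] += 1
        if attempts["count"] < 3:
            raise ModelServerRateLimitError("429")
        return "ok"

    assert call() == "ok"
    assert attempts["count"] == 3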