Repository: https://github.com/danswer-ai/danswer.git
Add handling for rate limiting (#3280)
Parent: 634a0b9398
Commit: ac448956e9
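In short: the model server now converts litellm's RateLimitError into an HTTP 429 response, the EmbeddingModel client raises a new ModelServerRateLimitError when it receives a 429, and passage (indexing-time) embedding requests retry on rate limits with a longer delay than ordinary failures. The commit also tidies up a few external group sync messages and adds an Azure embedding test fixture with a disabled rate-limit test.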
@@ -195,7 +195,7 @@ def connector_external_group_sync_generator_task(
     tenant_id: str | None,
 ) -> None:
     """
-    Permission sync task that handles document permission syncing for a given connector credential pair
+    Permission sync task that handles external group syncing for a given connector credential pair
     This task assumes that the task has already been properly fenced
     """

@@ -228,9 +228,13 @@ def connector_external_group_sync_generator_task(

         ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
         if ext_group_sync_func is None:
-            raise ValueError(f"No external group sync func found for {source_type}")
+            raise ValueError(
+                f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
+            )

-        logger.info(f"Syncing docs for {source_type}")
+        logger.info(
+            f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
+        )

         external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
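The two hunks above only touch wording: the docstring and log message described document permission syncing, but this task syncs external groups, and the cc_pair_id is added to both messages to make failures easier to trace.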
danswer/natural_language_processing/exceptions.py (new file):

@@ -0,0 +1,4 @@
+class ModelServerRateLimitError(Exception):
+    """
+    Exception raised for rate limiting errors from the model server.
+    """
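Defining a dedicated exception type lets callers distinguish backpressure from hard failures. A minimal usage sketch, assuming the danswer package is importable; the call_model_server helper is hypothetical, not part of this commit:

import requests

from danswer.natural_language_processing.exceptions import ModelServerRateLimitError


def call_model_server(url: str, payload: dict) -> dict:
    response = requests.post(url, json=payload)
    if response.status_code == 429:
        # raise the dedicated error so callers can back off and retry
        raise ModelServerRateLimitError(response.text)
    response.raise_for_status()
    return response.json()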
@@ -6,6 +6,9 @@ from typing import Any

 import requests
 from httpx import HTTPError
+from requests import JSONDecodeError
+from requests import RequestException
+from requests import Response
 from retry import retry

 from danswer.configs.app_configs import LARGE_CHUNK_RATIO

@@ -16,6 +19,9 @@ from danswer.configs.model_configs import (
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.models import SearchSettings
 from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
+from danswer.natural_language_processing.exceptions import (
+    ModelServerRateLimitError,
+)
 from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.natural_language_processing.utils import tokenizer_trim_content
 from danswer.utils.logger import setup_logger

@@ -99,28 +105,43 @@ class EmbeddingModel:
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

     def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
-        def _make_request() -> EmbedResponse:
+        def _make_request() -> Response:
             response = requests.post(
                 self.embed_server_endpoint, json=embed_request.model_dump()
             )
-            try:
-                response.raise_for_status()
-            except requests.HTTPError as e:
-                try:
-                    error_detail = response.json().get("detail", str(e))
-                except Exception:
-                    error_detail = response.text
-                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
-            except requests.RequestException as e:
-                raise HTTPError(f"Request failed: {str(e)}") from e
+            # signify that this is a rate limit error
+            if response.status_code == 429:
+                raise ModelServerRateLimitError(response.text)

-            return EmbedResponse(**response.json())
+            response.raise_for_status()
+            return response

-        # only perform retries for the non-realtime embedding of passages (e.g. for indexing)
+        final_make_request_func = _make_request
+
+        # if the text type is a passage, add some default
+        # retries + handling for rate limiting
         if embed_request.text_type == EmbedTextType.PASSAGE:
-            return retry(tries=3, delay=5)(_make_request)()
-        else:
-            return _make_request()
+            final_make_request_func = retry(
+                tries=3,
+                delay=5,
+                exceptions=(RequestException, ValueError, JSONDecodeError),
+            )(final_make_request_func)
+            # use 10 second delay as per Azure suggestion
+            final_make_request_func = retry(
+                tries=10, delay=10, exceptions=ModelServerRateLimitError
+            )(final_make_request_func)
+
+        try:
+            response = final_make_request_func()
+            return EmbedResponse(**response.json())
+        except requests.HTTPError as e:
+            try:
+                error_detail = response.json().get("detail", str(e))
+            except Exception:
+                error_detail = response.text
+            raise HTTPError(f"HTTP error occurred: {error_detail}") from e
+        except requests.RequestException as e:
+            raise HTTPError(f"Request failed: {str(e)}") from e

     def _batch_encode_texts(
         self,
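The core of the change above is composing two retry decorators around the same request function: the inner one retries ordinary transient failures (RequestException, ValueError, JSONDecodeError) with short delays, while the outer one catches only ModelServerRateLimitError and waits longer between attempts, so rate limits get a more patient backoff than flaky-network errors. A self-contained sketch of the same pattern, assuming the retry package from PyPI; RateLimitedError and flaky_call are made up for illustration, and the delays are set to 0 here (the commit uses 5s and 10s):

from retry import retry


class RateLimitedError(Exception):
    """Stand-in for ModelServerRateLimitError."""


attempts = {"n": 0}


def flaky_call() -> str:
    # simulate a server that rate-limits the first two attempts
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RateLimitedError("429 Too Many Requests")
    return "ok"


# inner retry: quick attempts for ordinary transient failures only
wrapped = retry(tries=3, delay=0, exceptions=ConnectionError)(flaky_call)
# outer retry: more patient attempts reserved for rate limiting; because the
# inner retry does not catch RateLimitedError, it falls through to this layer
wrapped = retry(tries=10, delay=0, exceptions=RateLimitedError)(wrapped)

assert wrapped() == "ok"  # succeeds on the third attempt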
@@ -11,6 +11,7 @@ from fastapi import APIRouter
 from fastapi import HTTPException
 from google.oauth2 import service_account  # type: ignore
 from litellm import embedding
+from litellm.exceptions import RateLimitError
 from retry import retry
 from sentence_transformers import CrossEncoder  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore

@@ -205,28 +206,22 @@ class CloudEmbedding:
         model_name: str | None = None,
         deployment_name: str | None = None,
     ) -> list[Embedding]:
-        try:
-            if self.provider == EmbeddingProvider.OPENAI:
-                return self._embed_openai(texts, model_name)
-            elif self.provider == EmbeddingProvider.AZURE:
-                return self._embed_azure(texts, f"azure/{deployment_name}")
-            elif self.provider == EmbeddingProvider.LITELLM:
-                return self._embed_litellm_proxy(texts, model_name)
+        if self.provider == EmbeddingProvider.OPENAI:
+            return self._embed_openai(texts, model_name)
+        elif self.provider == EmbeddingProvider.AZURE:
+            return self._embed_azure(texts, f"azure/{deployment_name}")
+        elif self.provider == EmbeddingProvider.LITELLM:
+            return self._embed_litellm_proxy(texts, model_name)

-            embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
-            if self.provider == EmbeddingProvider.COHERE:
-                return self._embed_cohere(texts, model_name, embedding_type)
-            elif self.provider == EmbeddingProvider.VOYAGE:
-                return self._embed_voyage(texts, model_name, embedding_type)
-            elif self.provider == EmbeddingProvider.GOOGLE:
-                return self._embed_vertex(texts, model_name, embedding_type)
-            else:
-                raise ValueError(f"Unsupported provider: {self.provider}")
-        except Exception as e:
-            raise HTTPException(
-                status_code=500,
-                detail=f"Error embedding text with {self.provider}: {str(e)}",
-            )
+        embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
+        if self.provider == EmbeddingProvider.COHERE:
+            return self._embed_cohere(texts, model_name, embedding_type)
+        elif self.provider == EmbeddingProvider.VOYAGE:
+            return self._embed_voyage(texts, model_name, embedding_type)
+        elif self.provider == EmbeddingProvider.GOOGLE:
+            return self._embed_vertex(texts, model_name, embedding_type)
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")

     @staticmethod
     def create(
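Dropping the blanket try/except in CloudEmbedding is what makes the rest of the commit work: previously every provider error, rate limits included, was flattened into a generic HTTP 500 here. Letting litellm's RateLimitError propagate allows the request handler below to report a proper 429.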
@@ -430,6 +425,11 @@ async def process_embed_request(
             prefix=prefix,
         )
         return EmbedResponse(embeddings=embeddings)
+    except RateLimitError as e:
+        raise HTTPException(
+            status_code=429,
+            detail=str(e),
+        )
     except Exception as e:
         exception_detail = f"Error during embedding process:\n{str(e)}"
         logger.exception(exception_detail)
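In isolation, the pattern is simply: catch the provider's rate-limit error before the generic handler and re-raise it as HTTP 429 so clients know to back off. A minimal FastAPI sketch; the /embed route and the embed_texts stub are illustrative, not the commit's actual endpoint:

from fastapi import FastAPI, HTTPException
from litellm.exceptions import RateLimitError

app = FastAPI()


def embed_texts(texts: list[str]) -> list[list[float]]:
    ...  # call the embedding provider here; may raise RateLimitError


@app.post("/embed")
async def embed(payload: dict) -> dict:
    try:
        return {"embeddings": embed_texts(payload["texts"])}
    except RateLimitError as e:
        # surface the upstream rate limit as a 429 so clients can back off
        raise HTTPException(status_code=429, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))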
@@ -7,6 +7,7 @@ from shared_configs.enums import EmbedTextType
 from shared_configs.model_server_models import EmbeddingProvider

 VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
+VALID_LONG_SAMPLE = ["hi " * 999]
 # openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
 # seem to be true
 TOO_LONG_SAMPLE = ["a"] * 2500

@@ -99,3 +100,42 @@ def local_nomic_embedding_model() -> EmbeddingModel:
 def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None:
     _run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768)
     _run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768)
+
+
+@pytest.fixture
+def azure_embedding_model() -> EmbeddingModel:
+    return EmbeddingModel(
+        server_host="localhost",
+        server_port=9000,
+        model_name="text-embedding-3-large",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+        api_key=os.getenv("AZURE_API_KEY"),
+        provider_type=EmbeddingProvider.AZURE,
+        api_url=os.getenv("AZURE_API_URL"),
+    )
+
+
+# NOTE (chris): this test doesn't work, and I do not know why
+# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel):
+#     """NOTE: this test relies on a very low rate limit for the Azure API +
+#     this test only being run once in a 1 minute window"""
+#     # VALID_LONG_SAMPLE is 999 tokens, so the second call should run into rate
+#     # limits assuming the limit is 1000 tokens per minute
+#     result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+#     assert len(result) == 1
+#     assert len(result[0]) == 1536
+
+#     # this should fail
+#     with pytest.raises(ModelServerRateLimitError):
+#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+
+#     # this should succeed, since passage requests retry up to 10 times
+#     start = time.time()
+#     result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.PASSAGE)
+#     assert len(result) == 1
+#     assert len(result[0]) == 1536
+#     assert time.time() - start > 30  # make sure we waited, even though we hit rate limits