Add handling for rate limiting (#3280)

Chris Weaver 2024-11-27 14:22:15 -08:00 committed by GitHub
parent 634a0b9398
commit ac448956e9
5 changed files with 109 additions and 40 deletions

File 1 of 5: background Celery task for external group syncing

@@ -195,7 +195,7 @@ def connector_external_group_sync_generator_task(
     tenant_id: str | None,
 ) -> None:
     """
-    Permission sync task that handles document permission syncing for a given connector credential pair
+    Permission sync task that handles external group syncing for a given connector credential pair
     This task assumes that the task has already been properly fenced
     """

@@ -228,9 +228,13 @@ def connector_external_group_sync_generator_task(
     ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
     if ext_group_sync_func is None:
-        raise ValueError(f"No external group sync func found for {source_type}")
+        raise ValueError(
+            f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
+        )

-    logger.info(f"Syncing docs for {source_type}")
+    logger.info(
+        f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
+    )

     external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)

File 2 of 5 (new file): danswer/natural_language_processing/exceptions.py

@@ -0,0 +1,4 @@
+class ModelServerRateLimitError(Exception):
+    """
+    Exception raised for rate limiting errors from the model server.
+    """

File 3 of 5: client-side EmbeddingModel that calls the model server

@@ -6,6 +6,9 @@ from typing import Any

 import requests
 from httpx import HTTPError
+from requests import JSONDecodeError
+from requests import RequestException
+from requests import Response
 from retry import retry

 from danswer.configs.app_configs import LARGE_CHUNK_RATIO
@@ -16,6 +19,9 @@ from danswer.configs.model_configs import (
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.models import SearchSettings
 from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
+from danswer.natural_language_processing.exceptions import (
+    ModelServerRateLimitError,
+)
 from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.natural_language_processing.utils import tokenizer_trim_content
 from danswer.utils.logger import setup_logger
@@ -99,28 +105,43 @@ class EmbeddingModel:
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

     def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
-        def _make_request() -> EmbedResponse:
+        def _make_request() -> Response:
             response = requests.post(
                 self.embed_server_endpoint, json=embed_request.model_dump()
             )
-            try:
-                response.raise_for_status()
-            except requests.HTTPError as e:
-                try:
-                    error_detail = response.json().get("detail", str(e))
-                except Exception:
-                    error_detail = response.text
-                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
-            except requests.RequestException as e:
-                raise HTTPError(f"Request failed: {str(e)}") from e
+            # signify that this is a rate limit error
+            if response.status_code == 429:
+                raise ModelServerRateLimitError(response.text)

-            return EmbedResponse(**response.json())
+            response.raise_for_status()
+            return response

-        # only perform retries for the non-realtime embedding of passages (e.g. for indexing)
+        final_make_request_func = _make_request
+
+        # if the text type is a passage, add some default
+        # retries + handling for rate limiting
         if embed_request.text_type == EmbedTextType.PASSAGE:
-            return retry(tries=3, delay=5)(_make_request)()
-        else:
-            return _make_request()
+            final_make_request_func = retry(
+                tries=3,
+                delay=5,
+                exceptions=(RequestException, ValueError, JSONDecodeError),
+            )(final_make_request_func)
+            # use 10 second delay as per Azure suggestion
+            final_make_request_func = retry(
+                tries=10, delay=10, exceptions=ModelServerRateLimitError
+            )(final_make_request_func)
+
+        try:
+            response = final_make_request_func()
+            return EmbedResponse(**response.json())
+        except requests.HTTPError as e:
+            try:
+                error_detail = response.json().get("detail", str(e))
+            except Exception:
+                error_detail = response.text
+            raise HTTPError(f"HTTP error occurred: {error_detail}") from e
+        except requests.RequestException as e:
+            raise HTTPError(f"Request failed: {str(e)}") from e

     def _batch_encode_texts(
         self,
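The two retry decorators above compose rather than conflict: the inner wrapper only catches transient request/parse errors, so a ModelServerRateLimitError falls straight through it and is retried only by the outer, more patient loop. A self-contained sketch of that stacking, with a stand-in function and delays zeroed out so it runs instantly:

from retry import retry


class ModelServerRateLimitError(Exception):
    """Stand-in for the exception added in this commit."""


calls = {"n": 0}


def flaky() -> str:
    # rate limited twice, then succeeds
    calls["n"] += 1
    if calls["n"] < 3:
        raise ModelServerRateLimitError("429: slow down")
    return "ok"


# inner: transient errors only; the rate-limit error is not in this tuple,
# so it propagates through immediately instead of burning these tries
inner = retry(tries=3, delay=0, exceptions=(ConnectionError, ValueError))(flaky)

# outer: rate-limit errors only, on the longer retry schedule
outer = retry(tries=10, delay=0, exceptions=ModelServerRateLimitError)(inner)

assert outer() == "ok"
assert calls["n"] == 3  # two rate-limited attempts, then success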

File 4 of 5: model server embedding endpoints

@@ -11,6 +11,7 @@ from fastapi import APIRouter
 from fastapi import HTTPException
 from google.oauth2 import service_account  # type: ignore
 from litellm import embedding
+from litellm.exceptions import RateLimitError
 from retry import retry
 from sentence_transformers import CrossEncoder  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore
@@ -205,28 +206,22 @@ class CloudEmbedding:
         model_name: str | None = None,
         deployment_name: str | None = None,
     ) -> list[Embedding]:
-        try:
-            if self.provider == EmbeddingProvider.OPENAI:
-                return self._embed_openai(texts, model_name)
-            elif self.provider == EmbeddingProvider.AZURE:
-                return self._embed_azure(texts, f"azure/{deployment_name}")
-            elif self.provider == EmbeddingProvider.LITELLM:
-                return self._embed_litellm_proxy(texts, model_name)
+        if self.provider == EmbeddingProvider.OPENAI:
+            return self._embed_openai(texts, model_name)
+        elif self.provider == EmbeddingProvider.AZURE:
+            return self._embed_azure(texts, f"azure/{deployment_name}")
+        elif self.provider == EmbeddingProvider.LITELLM:
+            return self._embed_litellm_proxy(texts, model_name)

-            embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
-            if self.provider == EmbeddingProvider.COHERE:
-                return self._embed_cohere(texts, model_name, embedding_type)
-            elif self.provider == EmbeddingProvider.VOYAGE:
-                return self._embed_voyage(texts, model_name, embedding_type)
-            elif self.provider == EmbeddingProvider.GOOGLE:
-                return self._embed_vertex(texts, model_name, embedding_type)
-            else:
-                raise ValueError(f"Unsupported provider: {self.provider}")
-        except Exception as e:
-            raise HTTPException(
-                status_code=500,
-                detail=f"Error embedding text with {self.provider}: {str(e)}",
-            )
+        embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
+        if self.provider == EmbeddingProvider.COHERE:
+            return self._embed_cohere(texts, model_name, embedding_type)
+        elif self.provider == EmbeddingProvider.VOYAGE:
+            return self._embed_voyage(texts, model_name, embedding_type)
+        elif self.provider == EmbeddingProvider.GOOGLE:
+            return self._embed_vertex(texts, model_name, embedding_type)
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")

     @staticmethod
     def create(
@@ -430,6 +425,11 @@ async def process_embed_request(
             prefix=prefix,
         )
         return EmbedResponse(embeddings=embeddings)
+    except RateLimitError as e:
+        raise HTTPException(
+            status_code=429,
+            detail=str(e),
+        )
    except Exception as e:
         exception_detail = f"Error during embedding process:\n{str(e)}"
         logger.exception(exception_detail)
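Together with the previous hunk, this lets a provider rate-limit error propagate out of CloudEmbedding (instead of being swallowed into a generic 500) and reach the route, where it becomes HTTP 429 for the client to retry against. An end-to-end sketch of that server-side flow, with the litellm call and its RateLimitError replaced by local stand-ins:

from fastapi import FastAPI, HTTPException

app = FastAPI()


class RateLimitError(Exception):
    """Stand-in for litellm.exceptions.RateLimitError."""


def call_provider() -> list[list[float]]:
    # stand-in for the real embedding call; simulate a provider 429
    raise RateLimitError("quota exceeded")


@app.post("/encoder/bi-encoder-embed")
def embed() -> dict[str, list[list[float]]]:
    try:
        return {"embeddings": call_provider()}
    except RateLimitError as e:
        # surfaced as 429 so the client-side retry logic can kick in
        raise HTTPException(status_code=429, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))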

File 5 of 5: embedding tests

@@ -7,6 +7,7 @@ from shared_configs.enums import EmbedTextType
 from shared_configs.model_server_models import EmbeddingProvider

 VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
+VALID_LONG_SAMPLE = ["hi " * 999]
 # openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
 # seem to be true
 TOO_LONG_SAMPLE = ["a"] * 2500
@@ -99,3 +100,42 @@ def local_nomic_embedding_model() -> EmbeddingModel:
 def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None:
     _run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768)
     _run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768)
+
+
+@pytest.fixture
+def azure_embedding_model() -> EmbeddingModel:
+    return EmbeddingModel(
+        server_host="localhost",
+        server_port=9000,
+        model_name="text-embedding-3-large",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+        api_key=os.getenv("AZURE_API_KEY"),
+        provider_type=EmbeddingProvider.AZURE,
+        api_url=os.getenv("AZURE_API_URL"),
+    )
+
+
+# NOTE (chris): this test doesn't work, and I do not know why
+# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel):
+#     """NOTE: this test relies on a very low rate limit for the Azure API +
+#     this test only being run once in a 1 minute window"""
+#     # VALID_LONG_SAMPLE is 999 tokens, so the second call should run into rate
+#     # limits assuming the limit is 1000 tokens per minute
+#     result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+#     assert len(result) == 1
+#     assert len(result[0]) == 1536
+
+#     # this should fail
+#     with pytest.raises(ModelServerRateLimitError):
+#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+
+#     # this should succeed, since passage requests retry up to 10 times
+#     start = time.time()
+#     result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.PASSAGE)
+#     assert len(result) == 1
+#     assert len(result[0]) == 1536
+#     assert time.time() - start > 30  # make sure we waited, even though we hit rate limits
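The disabled test above depends on live Azure quotas and wall-clock timing, which likely explains its flakiness. A deterministic alternative (a sketch under stated assumptions, not part of this commit) is to drive the same retry logic against a local stub that returns 429 a fixed number of times; the request and retry code below mirrors _make_model_server_request, with delay=0 for test speed:

import json
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer

import requests
from retry import retry


class ModelServerRateLimitError(Exception):
    """Stand-in mirroring danswer.natural_language_processing.exceptions."""


class StubHandler(BaseHTTPRequestHandler):
    calls = 0

    def do_POST(self) -> None:
        type(self).calls += 1
        if type(self).calls <= 2:  # first two calls are rate limited
            self.send_response(429)
            self.end_headers()
            self.wfile.write(b"rate limited")
            return
        body = json.dumps({"embeddings": [[0.0] * 4]}).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, *args: object) -> None:
        pass  # keep test output quiet


def test_passage_requests_retry_through_429s() -> None:
    server = HTTPServer(("localhost", 0), StubHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    url = f"http://localhost:{server.server_port}/encoder/bi-encoder-embed"

    def _make_request() -> requests.Response:
        response = requests.post(url, json={"texts": ["hi"]})
        if response.status_code == 429:
            raise ModelServerRateLimitError(response.text)
        response.raise_for_status()
        return response

    # same stacking as in _make_model_server_request, sped up for tests
    func = retry(tries=10, delay=0, exceptions=ModelServerRateLimitError)(
        _make_request
    )
    try:
        assert func().json()["embeddings"] == [[0.0] * 4]
        assert StubHandler.calls == 3  # two 429s, then success
    finally:
        server.shutdown()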