From 0d52e99bd484a8a1b47201a67a5f76b4687ff5aa Mon Sep 17 00:00:00 2001 From: Weves Date: Sun, 14 Jul 2024 16:35:24 -0700 Subject: [PATCH] Improve confluence rate limiting --- .../confluence/rate_limit_handler.py | 55 +++++++++++------ backend/model_server/encoders.py | 2 +- .../confluence/test_rate_limit_handler.py | 59 +++++++++++++++++++ 3 files changed, 97 insertions(+), 19 deletions(-) create mode 100644 backend/tests/unit/danswer/connectors/confluence/test_rate_limit_handler.py diff --git a/backend/danswer/connectors/confluence/rate_limit_handler.py b/backend/danswer/connectors/confluence/rate_limit_handler.py index b9481d6bd46..b0128eabaa8 100644 --- a/backend/danswer/connectors/confluence/rate_limit_handler.py +++ b/backend/danswer/connectors/confluence/rate_limit_handler.py @@ -1,10 +1,14 @@ +import time from collections.abc import Callable from typing import Any from typing import cast from typing import TypeVar from requests import HTTPError -from retry import retry + +from danswer.utils.logger import setup_logger + +logger = setup_logger() F = TypeVar("F", bound=Callable[..., Any]) @@ -18,23 +22,38 @@ class ConfluenceRateLimitError(Exception): def make_confluence_call_handle_rate_limit(confluence_call: F) -> F: - @retry( - exceptions=ConfluenceRateLimitError, - tries=10, - delay=1, - max_delay=600, # 10 minutes - backoff=2, - jitter=1, - ) def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: - try: - return confluence_call(*args, **kwargs) - except HTTPError as e: - if ( - e.response.status_code == 429 - or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower() - ): - raise ConfluenceRateLimitError() - raise + starting_delay = 5 + backoff = 2 + max_delay = 600 + + for attempt in range(10): + try: + return confluence_call(*args, **kwargs) + except HTTPError as e: + if ( + e.response.status_code == 429 + or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower() + ): + retry_after = None + try: + retry_after = int(e.response.headers.get("Retry-After")) + except ValueError: + pass + + if retry_after: + logger.warning( + f"Rate limit hit. Retrying after {retry_after} seconds..." + ) + time.sleep(retry_after) + else: + logger.warning( + "Rate limit hit. Retrying with exponential backoff..." + ) + delay = min(starting_delay * (backoff**attempt), max_delay) + time.sleep(delay) + else: + # re-raise, let caller handle + raise return cast(F, wrapped_call) diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 1c82698e9f0..abce09eb20b 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -9,7 +9,7 @@ import voyageai # type: ignore from cohere import Client as CohereClient from fastapi import APIRouter from fastapi import HTTPException -from google.oauth2 import service_account +from google.oauth2 import service_account # type: ignore from sentence_transformers import CrossEncoder # type: ignore from sentence_transformers import SentenceTransformer # type: ignore from vertexai.language_models import TextEmbeddingInput # type: ignore diff --git a/backend/tests/unit/danswer/connectors/confluence/test_rate_limit_handler.py b/backend/tests/unit/danswer/connectors/confluence/test_rate_limit_handler.py new file mode 100644 index 00000000000..92bccaa050d --- /dev/null +++ b/backend/tests/unit/danswer/connectors/confluence/test_rate_limit_handler.py @@ -0,0 +1,59 @@ +from unittest.mock import Mock +from unittest.mock import patch + +import pytest +from requests import HTTPError + +from danswer.connectors.confluence.rate_limit_handler import ( + make_confluence_call_handle_rate_limit, +) + + +@pytest.fixture +def mock_confluence_call() -> Mock: + return Mock() + + +@pytest.mark.parametrize( + "status_code,text,retry_after", + [ + (429, "Rate limit exceeded", "5"), + (200, "Rate limit exceeded", None), + (429, "Some other error", "5"), + ], +) +def test_rate_limit_handling( + mock_confluence_call: Mock, status_code: int, text: str, retry_after: str | None +) -> None: + with patch("time.sleep") as mock_sleep: + mock_confluence_call.side_effect = [ + HTTPError( + response=Mock( + status_code=status_code, + text=text, + headers={"Retry-After": retry_after} if retry_after else {}, + ) + ), + ] * 2 + ["Success"] + + handled_call = make_confluence_call_handle_rate_limit(mock_confluence_call) + result = handled_call() + + assert result == "Success" + assert mock_confluence_call.call_count == 3 + assert mock_sleep.call_count == 2 + if retry_after: + mock_sleep.assert_called_with(int(retry_after)) + + +def test_non_rate_limit_error(mock_confluence_call: Mock) -> None: + mock_confluence_call.side_effect = HTTPError( + response=Mock(status_code=500, text="Internal Server Error") + ) + + handled_call = make_confluence_call_handle_rate_limit(mock_confluence_call) + + with pytest.raises(HTTPError): + handled_call() + + assert mock_confluence_call.call_count == 1