Improve confluence rate limiting

This commit is contained in:
Weves
2024-07-14 16:35:24 -07:00
committed by Chris Weaver
parent 1b864a00e4
commit 0d52e99bd4
3 changed files with 97 additions and 19 deletions

View File

@@ -1,10 +1,14 @@
import time
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any
from typing import cast from typing import cast
from typing import TypeVar from typing import TypeVar
from requests import HTTPError from requests import HTTPError
from retry import retry
from danswer.utils.logger import setup_logger
logger = setup_logger()
F = TypeVar("F", bound=Callable[..., Any]) F = TypeVar("F", bound=Callable[..., Any])
@@ -18,23 +22,38 @@ class ConfluenceRateLimitError(Exception):
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
    """Wrap a Confluence API call so that rate-limit errors are retried.

    The returned callable forwards its arguments to ``confluence_call``. When
    the call raises an ``HTTPError`` whose response has status 429 or whose
    body contains ``RATE_LIMIT_MESSAGE_LOWERCASE``, the wrapper sleeps and
    tries again, up to 10 attempts in total:

    * if the response carries a positive integer ``Retry-After`` header, it
      sleeps for that many seconds;
    * otherwise it sleeps ``5 * 2**attempt`` seconds, capped at 600.

    Args:
        confluence_call: The callable to protect.

    Returns:
        A callable with the same signature as ``confluence_call``.

    Raises:
        HTTPError: Re-raised unchanged from the wrapped call when the error is
            not a rate-limit error, when it has no response attached, or when
            the final attempt is still rate limited.
    """

    def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
        max_attempts = 10
        starting_delay = 5
        backoff = 2
        max_delay = 600  # 10 minutes

        for attempt in range(max_attempts):
            try:
                return confluence_call(*args, **kwargs)
            except HTTPError as e:
                response = e.response
                if response is None:
                    # Nothing to inspect, so this cannot be classified as a
                    # rate-limit error; let the caller handle it.
                    raise

                is_rate_limited = (
                    response.status_code == 429
                    or RATE_LIMIT_MESSAGE_LOWERCASE in response.text.lower()
                )
                if not is_rate_limited:
                    # re-raise, let caller handle
                    raise

                if attempt == max_attempts - 1:
                    # Out of retries: surface the error instead of sleeping
                    # once more and silently returning None.
                    logger.warning(
                        f"Rate limit hit. Giving up after {max_attempts} attempts."
                    )
                    raise

                retry_after = None
                try:
                    # A missing header yields None (TypeError); a non-integer
                    # value such as an HTTP-date yields ValueError. Both fall
                    # back to exponential backoff.
                    retry_after = int(response.headers.get("Retry-After"))
                except (TypeError, ValueError):
                    pass

                # time.sleep rejects negative values, so only honor a
                # strictly positive header value.
                if retry_after is not None and retry_after > 0:
                    logger.warning(
                        f"Rate limit hit. Retrying after {retry_after} seconds..."
                    )
                    time.sleep(retry_after)
                else:
                    logger.warning(
                        "Rate limit hit. Retrying with exponential backoff..."
                    )
                    delay = min(starting_delay * (backoff**attempt), max_delay)
                    time.sleep(delay)

    return cast(F, wrapped_call)

View File

@@ -9,7 +9,7 @@ import voyageai # type: ignore
from cohere import Client as CohereClient from cohere import Client as CohereClient
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import HTTPException from fastapi import HTTPException
from google.oauth2 import service_account from google.oauth2 import service_account # type: ignore
from sentence_transformers import CrossEncoder # type: ignore from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore from sentence_transformers import SentenceTransformer # type: ignore
from vertexai.language_models import TextEmbeddingInput # type: ignore from vertexai.language_models import TextEmbeddingInput # type: ignore

View File

@@ -0,0 +1,59 @@
from unittest.mock import Mock
from unittest.mock import patch
import pytest
from requests import HTTPError
from danswer.connectors.confluence.rate_limit_handler import (
make_confluence_call_handle_rate_limit,
)
@pytest.fixture
def mock_confluence_call() -> Mock:
    """Provide a fresh ``Mock`` that stands in for a Confluence API call."""
    confluence_stub = Mock()
    return confluence_stub
@pytest.mark.parametrize(
    "status_code,text,retry_after",
    [
        (429, "Rate limit exceeded", "5"),
        (200, "Rate limit exceeded", None),
        (429, "Some other error", "5"),
    ],
)
def test_rate_limit_handling(
    mock_confluence_call: Mock, status_code: int, text: str, retry_after: str | None
) -> None:
    """Two rate-limited failures followed by a success are retried transparently.

    The wrapped call fails twice with the parametrized response and then
    returns ``"Success"``. ``time.sleep`` is patched so the test does not
    actually wait; every recorded sleep duration is checked, not only the
    last one.
    """
    with patch("time.sleep") as mock_sleep:
        mock_confluence_call.side_effect = [
            HTTPError(
                response=Mock(
                    status_code=status_code,
                    text=text,
                    headers={"Retry-After": retry_after} if retry_after else {},
                )
            ),
        ] * 2 + ["Success"]

        handled_call = make_confluence_call_handle_rate_limit(mock_confluence_call)
        result = handled_call()

    assert result == "Success"
    assert mock_confluence_call.call_count == 3
    assert mock_sleep.call_count == 2

    slept = [recorded.args[0] for recorded in mock_sleep.call_args_list]
    if retry_after:
        # The server-provided Retry-After value is honored on every retry.
        assert slept == [int(retry_after)] * 2
    else:
        # No header: exponential backoff starting at 5 seconds, doubling.
        assert slept == [5, 10]
def test_non_rate_limit_error(mock_confluence_call: Mock) -> None:
    """An HTTP error that is not a rate limit propagates after a single call."""
    server_error = HTTPError(
        response=Mock(status_code=500, text="Internal Server Error")
    )
    mock_confluence_call.side_effect = server_error

    handled_call = make_confluence_call_handle_rate_limit(mock_confluence_call)

    with pytest.raises(HTTPError):
        handled_call()

    # The wrapper must not retry errors it does not classify as rate limiting.
    assert mock_confluence_call.call_count == 1