mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-12 14:12:53 +02:00
Improve confluence rate limiting
This commit is contained in:
@ -1,10 +1,14 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
from requests import HTTPError
|
||||
from retry import retry
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
@ -18,23 +22,38 @@ class ConfluenceRateLimitError(Exception):
|
||||
|
||||
|
||||
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
@retry(
|
||||
exceptions=ConfluenceRateLimitError,
|
||||
tries=10,
|
||||
delay=1,
|
||||
max_delay=600, # 10 minutes
|
||||
backoff=2,
|
||||
jitter=1,
|
||||
)
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
try:
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
if (
|
||||
e.response.status_code == 429
|
||||
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
|
||||
):
|
||||
raise ConfluenceRateLimitError()
|
||||
raise
|
||||
starting_delay = 5
|
||||
backoff = 2
|
||||
max_delay = 600
|
||||
|
||||
for attempt in range(10):
|
||||
try:
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
if (
|
||||
e.response.status_code == 429
|
||||
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
|
||||
):
|
||||
retry_after = None
|
||||
try:
|
||||
retry_after = int(e.response.headers.get("Retry-After"))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if retry_after:
|
||||
logger.warning(
|
||||
f"Rate limit hit. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
time.sleep(retry_after)
|
||||
else:
|
||||
logger.warning(
|
||||
"Rate limit hit. Retrying with exponential backoff..."
|
||||
)
|
||||
delay = min(starting_delay * (backoff**attempt), max_delay)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
# re-raise, let caller handle
|
||||
raise
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
|
@ -9,7 +9,7 @@ import voyageai # type: ignore
|
||||
from cohere import Client as CohereClient
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from google.oauth2 import service_account
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingInput # type: ignore
|
||||
|
@ -0,0 +1,59 @@
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from requests import HTTPError
|
||||
|
||||
from danswer.connectors.confluence.rate_limit_handler import (
|
||||
make_confluence_call_handle_rate_limit,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_confluence_call() -> Mock:
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code,text,retry_after",
|
||||
[
|
||||
(429, "Rate limit exceeded", "5"),
|
||||
(200, "Rate limit exceeded", None),
|
||||
(429, "Some other error", "5"),
|
||||
],
|
||||
)
|
||||
def test_rate_limit_handling(
|
||||
mock_confluence_call: Mock, status_code: int, text: str, retry_after: str | None
|
||||
) -> None:
|
||||
with patch("time.sleep") as mock_sleep:
|
||||
mock_confluence_call.side_effect = [
|
||||
HTTPError(
|
||||
response=Mock(
|
||||
status_code=status_code,
|
||||
text=text,
|
||||
headers={"Retry-After": retry_after} if retry_after else {},
|
||||
)
|
||||
),
|
||||
] * 2 + ["Success"]
|
||||
|
||||
handled_call = make_confluence_call_handle_rate_limit(mock_confluence_call)
|
||||
result = handled_call()
|
||||
|
||||
assert result == "Success"
|
||||
assert mock_confluence_call.call_count == 3
|
||||
assert mock_sleep.call_count == 2
|
||||
if retry_after:
|
||||
mock_sleep.assert_called_with(int(retry_after))
|
||||
|
||||
|
||||
def test_non_rate_limit_error(mock_confluence_call: Mock) -> None:
|
||||
mock_confluence_call.side_effect = HTTPError(
|
||||
response=Mock(status_code=500, text="Internal Server Error")
|
||||
)
|
||||
|
||||
handled_call = make_confluence_call_handle_rate_limit(mock_confluence_call)
|
||||
|
||||
with pytest.raises(HTTPError):
|
||||
handled_call()
|
||||
|
||||
assert mock_confluence_call.call_count == 1
|
Reference in New Issue
Block a user