mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-28 13:53:28 +02:00
Improve confluence rate limiting
This commit is contained in:
@@ -1,10 +1,14 @@
|
|||||||
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
from retry import retry
|
|
||||||
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
F = TypeVar("F", bound=Callable[..., Any])
|
F = TypeVar("F", bound=Callable[..., Any])
|
||||||
@@ -18,23 +22,38 @@ class ConfluenceRateLimitError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||||
@retry(
|
|
||||||
exceptions=ConfluenceRateLimitError,
|
|
||||||
tries=10,
|
|
||||||
delay=1,
|
|
||||||
max_delay=600, # 10 minutes
|
|
||||||
backoff=2,
|
|
||||||
jitter=1,
|
|
||||||
)
|
|
||||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||||
try:
|
starting_delay = 5
|
||||||
return confluence_call(*args, **kwargs)
|
backoff = 2
|
||||||
except HTTPError as e:
|
max_delay = 600
|
||||||
if (
|
|
||||||
e.response.status_code == 429
|
for attempt in range(10):
|
||||||
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
|
try:
|
||||||
):
|
return confluence_call(*args, **kwargs)
|
||||||
raise ConfluenceRateLimitError()
|
except HTTPError as e:
|
||||||
raise
|
if (
|
||||||
|
e.response.status_code == 429
|
||||||
|
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
|
||||||
|
):
|
||||||
|
retry_after = None
|
||||||
|
try:
|
||||||
|
retry_after = int(e.response.headers.get("Retry-After"))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if retry_after:
|
||||||
|
logger.warning(
|
||||||
|
f"Rate limit hit. Retrying after {retry_after} seconds..."
|
||||||
|
)
|
||||||
|
time.sleep(retry_after)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Rate limit hit. Retrying with exponential backoff..."
|
||||||
|
)
|
||||||
|
delay = min(starting_delay * (backoff**attempt), max_delay)
|
||||||
|
time.sleep(delay)
|
||||||
|
else:
|
||||||
|
# re-raise, let caller handle
|
||||||
|
raise
|
||||||
|
|
||||||
return cast(F, wrapped_call)
|
return cast(F, wrapped_call)
|
||||||
|
@@ -9,7 +9,7 @@ import voyageai # type: ignore
|
|||||||
from cohere import Client as CohereClient
|
from cohere import Client as CohereClient
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from google.oauth2 import service_account
|
from google.oauth2 import service_account # type: ignore
|
||||||
from sentence_transformers import CrossEncoder # type: ignore
|
from sentence_transformers import CrossEncoder # type: ignore
|
||||||
from sentence_transformers import SentenceTransformer # type: ignore
|
from sentence_transformers import SentenceTransformer # type: ignore
|
||||||
from vertexai.language_models import TextEmbeddingInput # type: ignore
|
from vertexai.language_models import TextEmbeddingInput # type: ignore
|
||||||
|
@@ -0,0 +1,59 @@
|
|||||||
|
from unittest.mock import Mock
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from requests import HTTPError
|
||||||
|
|
||||||
|
from danswer.connectors.confluence.rate_limit_handler import (
|
||||||
|
make_confluence_call_handle_rate_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_confluence_call() -> Mock:
|
||||||
|
return Mock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"status_code,text,retry_after",
|
||||||
|
[
|
||||||
|
(429, "Rate limit exceeded", "5"),
|
||||||
|
(200, "Rate limit exceeded", None),
|
||||||
|
(429, "Some other error", "5"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_rate_limit_handling(
|
||||||
|
mock_confluence_call: Mock, status_code: int, text: str, retry_after: str | None
|
||||||
|
) -> None:
|
||||||
|
with patch("time.sleep") as mock_sleep:
|
||||||
|
mock_confluence_call.side_effect = [
|
||||||
|
HTTPError(
|
||||||
|
response=Mock(
|
||||||
|
status_code=status_code,
|
||||||
|
text=text,
|
||||||
|
headers={"Retry-After": retry_after} if retry_after else {},
|
||||||
|
)
|
||||||
|
),
|
||||||
|
] * 2 + ["Success"]
|
||||||
|
|
||||||
|
handled_call = make_confluence_call_handle_rate_limit(mock_confluence_call)
|
||||||
|
result = handled_call()
|
||||||
|
|
||||||
|
assert result == "Success"
|
||||||
|
assert mock_confluence_call.call_count == 3
|
||||||
|
assert mock_sleep.call_count == 2
|
||||||
|
if retry_after:
|
||||||
|
mock_sleep.assert_called_with(int(retry_after))
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_rate_limit_error(mock_confluence_call: Mock) -> None:
|
||||||
|
mock_confluence_call.side_effect = HTTPError(
|
||||||
|
response=Mock(status_code=500, text="Internal Server Error")
|
||||||
|
)
|
||||||
|
|
||||||
|
handled_call = make_confluence_call_handle_rate_limit(mock_confluence_call)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPError):
|
||||||
|
handled_call()
|
||||||
|
|
||||||
|
assert mock_confluence_call.call_count == 1
|
Reference in New Issue
Block a user