mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 17:51:54 +01:00
Add rate limiting wrapper + add to Document360
This commit is contained in:
parent
64ebaf2dda
commit
08909b40b0
@ -0,0 +1,86 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
class RateLimitTriedTooManyTimesError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class _RateLimitDecorator:
|
||||
"""Builds a generic wrapper/decorator for calls to external APIs that
|
||||
prevents making more than `max_calls` requests per `period`
|
||||
|
||||
Implementation inspired by the `ratelimit` library:
|
||||
https://github.com/tomasbasham/ratelimit.
|
||||
|
||||
NOTE: is not thread safe.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_calls: int,
|
||||
period: float, # in seconds
|
||||
sleep_time: float = 2, # in seconds
|
||||
sleep_backoff: float = 2, # applies exponential backoff
|
||||
max_num_sleep: int = 0,
|
||||
):
|
||||
self.max_calls = max_calls
|
||||
self.period = period
|
||||
self.sleep_time = sleep_time
|
||||
self.sleep_backoff = sleep_backoff
|
||||
self.max_num_sleep = max_num_sleep
|
||||
|
||||
self.call_history: list[float] = []
|
||||
self.curr_calls = 0
|
||||
|
||||
def __call__(self, func: F) -> F:
|
||||
@wraps(func)
|
||||
def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
|
||||
# cleanup calls which are no longer relevant
|
||||
self._cleanup()
|
||||
|
||||
# check if we've exceeded the rate limit
|
||||
sleep_cnt = 0
|
||||
while len(self.call_history) == self.max_calls:
|
||||
sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
|
||||
logger.info(
|
||||
f"Rate limit exceeded for function {func.__name__}. "
|
||||
f"Waiting {sleep_time} seconds before retrying."
|
||||
)
|
||||
time.sleep(sleep_time)
|
||||
sleep_cnt += 1
|
||||
if self.max_num_sleep != 0 and sleep_cnt >= self.max_num_sleep:
|
||||
raise RateLimitTriedTooManyTimesError(
|
||||
f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'"
|
||||
)
|
||||
|
||||
self._cleanup()
|
||||
|
||||
# add the current call to the call history
|
||||
self.call_history.append(time.monotonic())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return cast(F, wrapped_func)
|
||||
|
||||
def _cleanup(self) -> None:
|
||||
curr_time = time.monotonic()
|
||||
time_to_expire_before = curr_time - self.period
|
||||
self.call_history = [
|
||||
call_time
|
||||
for call_time in self.call_history
|
||||
if call_time > time_to_expire_before
|
||||
]
|
||||
|
||||
|
||||
rate_limit_builder = _RateLimitDecorator
|
@ -9,6 +9,10 @@ import requests
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
@ -46,6 +50,11 @@ class Document360Connector(LoadConnector, PollConnector):
|
||||
self.portal_id = credentials.get("portal_id")
|
||||
return None
|
||||
|
||||
# rate limiting set based on the enterprise plan: https://apidocs.document360.com/apidocs/rate-limiting
|
||||
# NOTE: retry will handle cases where user is not on enterprise plan - we will just hit the rate limit
|
||||
# and then retry after a period
|
||||
@retry_builder()
|
||||
@rate_limit_builder(max_calls=100, period=60)
|
||||
def _make_request(self, endpoint: str, params: Optional[dict] = None) -> Any:
|
||||
if not self.api_token:
|
||||
raise ConnectorMissingCredentialError("Document360")
|
||||
|
@ -0,0 +1,36 @@
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
|
||||
|
||||
class TestRateLimit(unittest.TestCase):
|
||||
call_cnt = 0
|
||||
|
||||
def test_rate_limit_basic(self) -> None:
|
||||
self.call_cnt = 0
|
||||
|
||||
@rate_limit_builder(max_calls=2, period=5)
|
||||
def func() -> None:
|
||||
self.call_cnt += 1
|
||||
|
||||
start = time.time()
|
||||
|
||||
# make calls that shouldn't be rate-limited
|
||||
func()
|
||||
func()
|
||||
time_to_finish_non_ratelimited = time.time() - start
|
||||
|
||||
# make a call which SHOULD be rate-limited
|
||||
func()
|
||||
time_to_finish_ratelimited = time.time() - start
|
||||
|
||||
self.assertEqual(self.call_cnt, 3)
|
||||
self.assertLess(time_to_finish_non_ratelimited, 1)
|
||||
self.assertGreater(time_to_finish_ratelimited, 5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user