mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 20:24:32 +02:00
Add rate limiting wrapper + add to Document360
This commit is contained in:
@@ -0,0 +1,86 @@
|
|||||||
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any
|
||||||
|
from typing import cast
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
|
# Module-level logger; used below to report when a call is delayed by the rate limit.
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
|
# Generic callable type: lets the decorator declare that it returns the same
# callable type it was given.
F = TypeVar("F", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitTriedTooManyTimesError(Exception):
    """Raised when a rate-limited call gives up after reaching its configured
    maximum number of sleeps."""
|
||||||
|
|
||||||
|
|
||||||
|
class _RateLimitDecorator:
    """Builds a generic wrapper/decorator for calls to external APIs that
    prevents making more than `max_calls` requests per `period`.

    Implementation inspired by the `ratelimit` library:
    https://github.com/tomasbasham/ratelimit.

    Call timestamps are stored on the decorator instance, so every call made
    through a function wrapped by the same instance counts against one shared
    budget.

    NOTE: is not thread safe.

    Args:
        max_calls: maximum number of calls allowed within any trailing
            `period` seconds. Must be at least 1.
        period: length of the window, in seconds.
        sleep_time: base wait, in seconds, when the limit has been reached.
        sleep_backoff: multiplier applied to the wait for each consecutive
            sleep within a single call (exponential backoff).
        max_num_sleep: maximum number of sleeps a single call may perform
            before giving up. `0` means keep waiting indefinitely.

    Raises:
        ValueError: if `max_calls` is less than 1 (such a limit could never
            be satisfied and the wrapper would wait forever).
    """

    def __init__(
        self,
        max_calls: int,
        period: float,  # in seconds
        sleep_time: float = 2,  # in seconds
        sleep_backoff: float = 2,  # applies exponential backoff
        max_num_sleep: int = 0,
    ):
        if max_calls < 1:
            raise ValueError(f"max_calls must be at least 1, got {max_calls}")

        self.max_calls = max_calls
        self.period = period
        self.sleep_time = sleep_time
        self.sleep_backoff = sleep_backoff
        self.max_num_sleep = max_num_sleep

        # monotonic timestamps of the calls made within the current window
        self.call_history: list[float] = []
        # not read anywhere in this class; kept so the instance's attributes
        # stay the same for any external code that inspects them
        self.curr_calls = 0

    def __call__(self, func: F) -> F:
        """Wrap `func` so each invocation first waits for rate-limit capacity.

        Returns:
            The wrapped callable, typed the same as `func`.

        The wrapped callable raises `RateLimitTriedTooManyTimesError` when
        `max_num_sleep` is non-zero and the limit is still exceeded after
        that many sleeps.
        """

        @wraps(func)
        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
            # cleanup calls which are no longer relevant
            self._cleanup()

            # check if we've exceeded the rate limit
            sleep_cnt = 0
            # `>=` rather than `==` so an over-full history can never slip
            # past the check
            while len(self.call_history) >= self.max_calls:
                # Only give up if we are STILL rate limited after using up
                # all allowed sleeps; the history has just been refreshed, so
                # a window that cleared during the last sleep is honoured.
                if self.max_num_sleep != 0 and sleep_cnt >= self.max_num_sleep:
                    raise RateLimitTriedTooManyTimesError(
                        f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'"
                    )

                sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
                logger.info(
                    f"Rate limit exceeded for function {func.__name__}. "
                    f"Waiting {sleep_time} seconds before retrying."
                )
                time.sleep(sleep_time)
                sleep_cnt += 1

                # drop calls that expired while we were sleeping
                self._cleanup()

            # add the current call to the call history
            self.call_history.append(time.monotonic())
            return func(*args, **kwargs)

        return cast(F, wrapped_func)

    def _cleanup(self) -> None:
        """Drop recorded calls that are older than `period` seconds."""
        curr_time = time.monotonic()
        time_to_expire_before = curr_time - self.period
        self.call_history = [
            call_time
            for call_time in self.call_history
            if call_time > time_to_expire_before
        ]
|
||||||
|
|
||||||
|
|
||||||
|
# Public name for the decorator class; applied as
# `@rate_limit_builder(max_calls=..., period=...)`.
rate_limit_builder = _RateLimitDecorator
|
@@ -9,6 +9,10 @@ import requests
|
|||||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||||
from danswer.configs.constants import DocumentSource
|
from danswer.configs.constants import DocumentSource
|
||||||
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
|
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
|
||||||
|
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||||
|
rate_limit_builder,
|
||||||
|
)
|
||||||
|
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||||
from danswer.connectors.interfaces import LoadConnector
|
from danswer.connectors.interfaces import LoadConnector
|
||||||
from danswer.connectors.interfaces import PollConnector
|
from danswer.connectors.interfaces import PollConnector
|
||||||
@@ -46,6 +50,11 @@ class Document360Connector(LoadConnector, PollConnector):
|
|||||||
self.portal_id = credentials.get("portal_id")
|
self.portal_id = credentials.get("portal_id")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# rate limiting set based on the enterprise plan: https://apidocs.document360.com/apidocs/rate-limiting
|
||||||
|
# NOTE: retry will handle cases where user is not on enterprise plan - we will just hit the rate limit
|
||||||
|
# and then retry after a period
|
||||||
|
@retry_builder()
|
||||||
|
@rate_limit_builder(max_calls=100, period=60)
|
||||||
def _make_request(self, endpoint: str, params: Optional[dict] = None) -> Any:
|
def _make_request(self, endpoint: str, params: Optional[dict] = None) -> Any:
|
||||||
if not self.api_token:
|
if not self.api_token:
|
||||||
raise ConnectorMissingCredentialError("Document360")
|
raise ConnectorMissingCredentialError("Document360")
|
||||||
|
@@ -0,0 +1,36 @@
|
|||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||||
|
rate_limit_builder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimit(unittest.TestCase):
    """Checks that `rate_limit_builder` delays calls beyond the allowed budget."""

    # counts invocations of the wrapped function; reset at the start of the test
    call_cnt = 0

    def test_rate_limit_basic(self) -> None:
        """Two calls fit in the budget and return quickly; the third must wait
        until the window (5 seconds) has passed."""
        self.call_cnt = 0

        @rate_limit_builder(max_calls=2, period=5)
        def func() -> None:
            self.call_cnt += 1

        # Use the monotonic clock for elapsed-time measurement: unlike
        # time.time(), it cannot jump if the system clock is adjusted while
        # the test is sleeping.
        start = time.monotonic()

        # make calls that shouldn't be rate-limited
        func()
        func()
        time_to_finish_non_ratelimited = time.monotonic() - start

        # make a call which SHOULD be rate-limited
        func()
        time_to_finish_ratelimited = time.monotonic() - start

        self.assertEqual(self.call_cnt, 3)
        self.assertLess(time_to_finish_non_ratelimited, 1)
        self.assertGreater(time_to_finish_ratelimited, 5)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Reference in New Issue
Block a user