From 3fb68af4056fd80ecb581acbdd57b12a51e4a9fc Mon Sep 17 00:00:00 2001 From: Weves Date: Mon, 29 Apr 2024 20:12:50 -0700 Subject: [PATCH] Address rate limiting for Notion --- .../rate_limit_wrapper.py | 44 +++++++++++++++++++ .../danswer/connectors/notion/connector.py | 14 +++--- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/backend/danswer/connectors/cross_connector_utils/rate_limit_wrapper.py b/backend/danswer/connectors/cross_connector_utils/rate_limit_wrapper.py index 43baced17..8733ca66e 100644 --- a/backend/danswer/connectors/cross_connector_utils/rate_limit_wrapper.py +++ b/backend/danswer/connectors/cross_connector_utils/rate_limit_wrapper.py @@ -5,6 +5,8 @@ from typing import Any from typing import cast from typing import TypeVar +import requests + from danswer.utils.logger import setup_logger logger = setup_logger() @@ -84,3 +86,45 @@ class _RateLimitDecorator: rate_limit_builder = _RateLimitDecorator + + +"""If you want to allow the external service to tell you when you've hit the rate limit, +use the following instead""" + +R = TypeVar("R", bound=Callable[..., requests.Response]) + + +def wrap_request_to_handle_ratelimiting( + request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30 +) -> R: + def wrapped_request(*args: list, **kwargs: dict[str, Any]) -> requests.Response: + for _ in range(max_waits): + response = request_fn(*args, **kwargs) + if response.status_code == 429: + try: + wait_time = int( + response.headers.get("Retry-After", default_wait_time_sec) + ) + except ValueError: + wait_time = default_wait_time_sec + + time.sleep(wait_time) + continue + + return response + + raise RateLimitTriedTooManyTimesError(f"Exceeded '{max_waits}' retries") + + return cast(R, wrapped_request) + + +_rate_limited_get = wrap_request_to_handle_ratelimiting(requests.get) +_rate_limited_post = wrap_request_to_handle_ratelimiting(requests.post) + + +class _RateLimitedRequest: + get = _rate_limited_get + post = _rate_limited_post + + +rl_requests = _RateLimitedRequest diff --git a/backend/danswer/connectors/notion/connector.py b/backend/danswer/connectors/notion/connector.py index 2fab13878..2d2897177 100644 --- a/backend/danswer/connectors/notion/connector.py +++ b/backend/danswer/connectors/notion/connector.py @@ -7,12 +7,14 @@ from datetime import timezone from typing import Any from typing import Optional -import requests from retry import retry from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.app_configs import NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP from danswer.configs.constants import DocumentSource +from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( + rl_requests, +) from danswer.connectors.interfaces import GenerateDocumentsOutput from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import PollConnector @@ -100,7 +102,7 @@ class NotionConnector(LoadConnector, PollConnector): logger.debug(f"Fetching children of block with ID '{block_id}'") block_url = f"https://api.notion.com/v1/blocks/{block_id}/children" query_params = None if not cursor else {"start_cursor": cursor} - res = requests.get( + res = rl_requests.get( block_url, headers=self.headers, params=query_params, @@ -127,7 +129,7 @@ class NotionConnector(LoadConnector, PollConnector): """Fetch a page from it's ID via the Notion API.""" logger.debug(f"Fetching page for ID '{page_id}'") block_url = f"https://api.notion.com/v1/pages/{page_id}" - res = requests.get( + res = rl_requests.get( block_url, headers=self.headers, timeout=_NOTION_CALL_TIMEOUT, @@ -147,7 +149,7 @@ class NotionConnector(LoadConnector, PollConnector): logger.debug(f"Fetching database for ID '{database_id}'") block_url = f"https://api.notion.com/v1/databases/{database_id}/query" body = None if not cursor else {"start_cursor": cursor} - res = requests.post( + res = rl_requests.post( block_url, headers=self.headers, json=body, @@ -327,7 +329,7 @@ class NotionConnector(LoadConnector, PollConnector): """Search for pages from a Notion database. Includes some small number of retries to handle misc, flakey failures.""" logger.debug(f"Searching for pages in Notion with query_dict: {query_dict}") - res = requests.post( + res = rl_requests.post( "https://api.notion.com/v1/search", headers=self.headers, json=query_dict, @@ -435,6 +437,8 @@ class NotionConnector(LoadConnector, PollConnector): yield from batch_generator(self._read_pages(pages), self.batch_size) if db_res.has_more: query_dict["start_cursor"] = db_res.next_cursor + else: + break else: break