Address rate limiting for Notion

This commit is contained in:
Weves 2024-04-29 20:12:50 -07:00 committed by Chris Weaver
parent 5b93e786ad
commit 3fb68af405
2 changed files with 53 additions and 5 deletions

View File

@ -5,6 +5,8 @@ from typing import Any
from typing import cast
from typing import TypeVar
import requests
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -84,3 +86,45 @@ class _RateLimitDecorator:
rate_limit_builder = _RateLimitDecorator
"""If you want to allow the external service to tell you when you've hit the rate limit,
use the following instead"""
R = TypeVar("R", bound=Callable[..., requests.Response])
def wrap_request_to_handle_ratelimiting(
request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30
) -> R:
def wrapped_request(*args: list, **kwargs: dict[str, Any]) -> requests.Response:
for _ in range(max_waits):
response = request_fn(*args, **kwargs)
if response.status_code == 429:
try:
wait_time = int(
response.headers.get("Retry-After", default_wait_time_sec)
)
except ValueError:
wait_time = default_wait_time_sec
time.sleep(wait_time)
continue
return response
raise RateLimitTriedTooManyTimesError(f"Exceeded '{max_waits}' retries")
return cast(R, wrapped_request)
_rate_limited_get = wrap_request_to_handle_ratelimiting(requests.get)
_rate_limited_post = wrap_request_to_handle_ratelimiting(requests.post)
class _RateLimitedRequest:
get = _rate_limited_get
post = _rate_limited_post
rl_requests = _RateLimitedRequest

View File

@ -7,12 +7,14 @@ from datetime import timezone
from typing import Any
from typing import Optional
import requests
from retry import retry
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rl_requests,
)
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
@ -100,7 +102,7 @@ class NotionConnector(LoadConnector, PollConnector):
logger.debug(f"Fetching children of block with ID '{block_id}'")
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
query_params = None if not cursor else {"start_cursor": cursor}
res = requests.get(
res = rl_requests.get(
block_url,
headers=self.headers,
params=query_params,
@ -127,7 +129,7 @@ class NotionConnector(LoadConnector, PollConnector):
"""Fetch a page from it's ID via the Notion API."""
logger.debug(f"Fetching page for ID '{page_id}'")
block_url = f"https://api.notion.com/v1/pages/{page_id}"
res = requests.get(
res = rl_requests.get(
block_url,
headers=self.headers,
timeout=_NOTION_CALL_TIMEOUT,
@ -147,7 +149,7 @@ class NotionConnector(LoadConnector, PollConnector):
logger.debug(f"Fetching database for ID '{database_id}'")
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
body = None if not cursor else {"start_cursor": cursor}
res = requests.post(
res = rl_requests.post(
block_url,
headers=self.headers,
json=body,
@ -327,7 +329,7 @@ class NotionConnector(LoadConnector, PollConnector):
"""Search for pages from a Notion database. Includes some small number of
retries to handle misc, flakey failures."""
logger.debug(f"Searching for pages in Notion with query_dict: {query_dict}")
res = requests.post(
res = rl_requests.post(
"https://api.notion.com/v1/search",
headers=self.headers,
json=query_dict,
@ -435,6 +437,8 @@ class NotionConnector(LoadConnector, PollConnector):
yield from batch_generator(self._read_pages(pages), self.batch_size)
if db_res.has_more:
query_dict["start_cursor"] = db_res.next_cursor
else:
break
else:
break